By: Philip Zucker
Re-posted from: https://www.philipzucker.com/egraph-2/
Last time we spent a bit of effort building an e-graph data structure in Julia. The e-graph is a data structure that compactly stores and maintains equivalence classes of terms.
We can use the data structure for equational rewriting by iteratively pattern matching terms inside of it and adding new equalities found by instantiating those patterns. For example, we could use the pattern ~x + 0 to instantiate the rule ~x + 0 == ~x, find that ~x binds to foo23(a,b,c) in the e-graph, and then insert into the e-graph the equality foo23(a,b,c) + 0 == foo23(a,b,c), and do this over and over. The advantage of the e-graph vs just rewriting terms is that we never lose information while doing this or rewrite ourselves into a corner. The disadvantage is that we maintain this huge data structure.
The rewrite rule problem can be viewed as something like a graph. The vertices of the graph are every possible term. A rewrite rule that is applicable to a term describes a directed edge out of that vertex.
This graph is very large, often infinite, so it can only be described implicitly.
This graph perspective fails to capture some useful properties though. Treating each term as an indivisible vertex fails to capture that rewriting can happen in only a small part of the term, with the vast majority of it left unchanged. There is a lot of opportunity for shared computation.
The EGraph from this perspective is a data structure for holding the already seen set of vertices efficiently and sharing these computations.
You can use this procedure for theorem proving. Insert the two terms you want to prove equal into the e-graph. Iteratively rewrite using your equalities. At each iteration, check whether two nodes in the e-graph have become equivalent. If so, congrats, problem proved! This is the analog of finding a path between two nodes in a graph by doing a bidirectional search from the start and endpoint. This is roughly the main method by which Z3 reasons about syntactic terms under equality in the presence of quantified equations. There are also methods by which to reconstruct the proof found if you want it.
Another application of basically the same algorithm is finding optimized rewrites. Apply rewrites until you’re sick of it, then extract from the equivalence class of your query term your favorite equivalent term. This is a very intriguing application to me as it produces more than just a yes/no answer.
Pattern Matching
Pattern matching takes a pattern with named holes in it and tries to find a way to fill the holes to match a term. The package SymbolicUtils.jl is useful for pattern matching against single terms, which is what I’ll describe in this section, but not quite set up to pattern match in an e-graph.
Pattern matching against a term is a very straightforward process once you sit down to do it. Where it gets complicated is trying to be efficient. But one should walk before jogging. I sometimes find myself so overwhelmed by performance considerations that I never get started.
First, one needs a data type for describing patterns (well, one doesn’t need it, but it’s nice. You could directly use pattern matching combinators).
struct PatVar
    id::Symbol
end

struct PatTerm
    head::Symbol
    args::Array{Union{PatTerm, PatVar}}
end

Pattern = Union{PatTerm, PatVar}
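For a concrete feel for the encoding, here is how the pattern f(~x, g(~x)) would be written with these types (my own illustrative example):
p = PatTerm(:f, [PatVar(:x), PatTerm(:g, [PatVar(:x)])])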
There is an analogous data structure in SymbolicUtils which I probably should've just used. For example, Slot is basically PatVar.
The pattern matching function takes a Term and a Pattern and returns a dictionary Dict{PatVar,Term} of bindings of pattern variables to terms, or nothing if it fails.
Matching a pattern variable builds a dictionary of bindings from that variable to a term.
match(t::Term, p::PatVar ) = Dict(p => t)
Matching a PatTerm requires that we check whether anything is obviously wrong (wrong head, or wrong number of args) and then recurse on the args.
The resulting binding dictionaries are checked for consistency of their bindings and merged.
# Merge a list of binding dictionaries. Fails (returns nothing) if any sub-match
# failed or if the same pattern variable is bound to two different terms.
function merge_consistent(ds)
    newd = Dict()
    for d in ds
        d === nothing && return nothing
        for (k, v) in d
            if haskey(newd, k)
                if newd[k] != v
                    return nothing
                end
            else
                newd[k] = v
            end
        end
    end
    return newd
end
function match(t::Term, p::PatTerm)
    if t.head != p.head || length(t.args) != length(p.args)
        return nothing
    else
        merge_consistent([match(t1, p1) for (t1, p1) in zip(t.args, p.args)])
    end
end
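Here is a quick sanity check of the matcher (my own example), assuming a plain syntax-tree Term whose head is a Symbol and whose args hold child Terms:
a  = Term(:a, [])
ga = Term(:g, [a])
t  = Term(:f, [ga, a])                           # the term f(g(a), a)

match(t, PatTerm(:f, [PatVar(:x), PatVar(:y)]))  # Dict(PatVar(:x) => g(a), PatVar(:y) => a)
match(t, PatTerm(:f, [PatVar(:x), PatVar(:x)]))  # nothing: g(a) and a can't both bind ~x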
There are other ways of arranging this computation, such as passing a dictionary parameter along and modifying it, but I find the above purely functional definition the clearest.
Pattern Matching in E-Graphs
The twist that E-Graphs throw into the pattern matching process is that instead of checking whether a child of a term matches the pattern, we need to check for any possible child matching in the equivalence class of children. This means our pattern matching procedure contains a backtracking search.
The de Moura and Bjorner e-matching paper describes how to do this efficiently via an interpreted virtual machine. An explicit virtual machine is important for them because the patterns change during the Z3 search process, but it seems less relevant for the use case here or in egg. It may be better to use partial evaluation techniques to remove the existence of the virtual machine at runtime like in A Typed, Algebraic Approach to Parsing but I haven’t figured out how to do it yet.
The Simplify paper describes a much simpler, inefficient pattern matching algorithm in terms of generators. This can be directly translated to Julia using Channels. There is a nice blog post describing how to build generators in Julia that work like Python generators. Basically, as soon as you define your generator function, wrap the whole thing in a Channel() do c block, and when you want to yield myvalue, call put!(c, myvalue).
# https://www.hpl.hp.com/techreports/2003/HPL-2003-148.pdf
function ematchlist(e::EGraph, t::Array{Union{PatTerm, PatVar}}, v::Array{Id}, sub)
    Channel() do c
        if length(t) == 0
            put!(c, sub)
        else
            for sub1 in ematch(e, t[1], v[1], sub)
                for sub2 in ematchlist(e, t[2:end], v[2:end], sub1)
                    put!(c, sub2)
                end
            end
        end
    end
end

# sub should be a map from pattern variables to Id
function ematch(e::EGraph, t::PatVar, v::Id, sub)
    Channel() do c
        if haskey(sub, t)
            if find_root!(e, sub[t]) == find_root!(e, v)
                put!(c, sub)
            end
        else
            put!(c, Base.ImmutableDict(sub, t => find_root!(e, v)))
        end
    end
end

function ematch(e::EGraph, t::PatTerm, v::Id, sub)
    Channel() do c
        for n in e.classes[find_root!(e, v)].nodes
            if n.head == t.head
                for sub1 in ematchlist(e, t.args, n.args, sub)
                    put!(c, sub1)
                end
            end
        end
    end
end
You can then instantiate patterns with the returned dictionaries via
function instantiate(e::EGraph, p::PatVar, sub)
    sub[p]
end

function instantiate(e::EGraph, p::PatTerm, sub)
    push!(e, Term(p.head, [instantiate(e, a, sub) for a in p.args]))
end
And build rewrite rules
struct Rule
    lhs::Pattern
    rhs::Pattern
end

function rewrite!(e::EGraph, r::Rule)
    matches = []
    EMPTY_DICT2 = Base.ImmutableDict{PatVar, Id}(PatVar(:____), Id(-1))
    for (n, cls) in e.classes
        for sub in ematch(e, r.lhs, n, EMPTY_DICT2)
            push!(matches, (instantiate(e, r.lhs, sub), instantiate(e, r.rhs, sub)))
        end
    end
    for (l, r) in matches
        union!(e, l, r)
    end
    rebuild!(e)
end
Here’s a very simple equation proving function that takes in a pile of rules
function prove_eq(e::EGraph, t1::Id, t2::Id, rules)
    for i in 1:3
        if in_same_set(e, t1, t2)
            return true
        end
        for r in rules
            rewrite!(e, r) # I should split this stuff up. We only need to rebuild at the end
        end
    end
    return nothing
end
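A rough end-to-end sketch of how this might be called, assuming the EGraph constructor, push!, and Term (with Id arguments) from the previous post, plus the plus_zero rule from above:
e    = EGraph()
zero = push!(e, Term(Symbol("0"), []))
a    = push!(e, Term(:a, []))
lhs  = push!(e, Term(:+, [a, zero]))   # the term a + 0
prove_eq(e, lhs, a, [plus_zero])       # should come back true after one round of rewriting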
As sometimes happens, I’m losing steam on this and would like to work on something different for a bit. But I loaded up the WIP at https://github.com/philzook58/EGraphs.jl.
Bits and Bobbles
Catlab
You can find piles of equational axioms in Catlab like here for example
A tricky thing is that some equational axioms of categories seem to produce variables out of thin air.
Consider the rewrite rule ~f => id(~A) ⋅ ~f. Where does the A come from? It's because we've highly suppressed all the typing information. The A should come from the type of f.
I think the simplest way to fix this is the "type tag" approach I mentioned here https://www.philipzucker.com/theorem-proving-for-catlab-2-lets-try-z3-this-time-nope/ and which can be read about here https://people.mpi-inf.mpg.de/~jblanche/mono-trans.pdf. You wrap every term in a tagging function so that f becomes tag(f, Hom(a,b)). This approach makes sense because it turns GATs purely equational without the implicit typing context, I think. It is a bummer that it will probably choke the e-graph with junk though.
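To make the ~f => id(~A) ⋅ ~f example concrete, here is a guess at what the identity rule might look like in the tagged encoding, using the Pattern types from above; the point is that ~A now appears on the left-hand side, so it is no longer conjured out of thin air:
lhs = PatTerm(:tag, [PatVar(:f), PatTerm(:Hom, [PatVar(:A), PatVar(:B)])])
rhs = PatTerm(:tag, [
          PatTerm(:compose, [
              PatTerm(:tag, [PatTerm(:id, [PatVar(:A)]), PatTerm(:Hom, [PatVar(:A), PatVar(:A)])]),
              PatTerm(:tag, [PatVar(:f), PatTerm(:Hom, [PatVar(:A), PatVar(:B)])])]),
          PatTerm(:Hom, [PatVar(:A), PatVar(:B)])])
identity_tagged = Rule(lhs, rhs)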
It may be best to build a TypedTerm to use rather than Term, one that has an explicit field for the type. This brings it close to the representation that Catlab already uses for syntax, so maybe I should just directly use that. I like having control over my own type definitions though, and Catlab plays some monkey business with the Julia level types. 🙁
struct TypedTerm
    type
    head
    args
end
As I have written it so far, I don't allow the patterns to match on the heads of terms, although there isn't any technical reason to not allow it. A natural use of a TypedTerm would probably want to do so though, since you'd sometimes want to match on the type while keeping the head a pattern. Another possible way to do this that avoids all these issues is to actually make terms a bit like how regular Julia Expr are made, which usually have :call in the head position. Then by convention the first arg could be the actual head, the second arg the type, and the rest of the arguments the regular arguments.
A confusing example sketch:
TypedTerm(:mcopy, typedterm(Hom(a,otimes(a,a))) , [TypedTerm(a,Ob,[])])
would become
Term(:call, [:mcopy, term!(Hom(a,otimes(a,a))), term!(a) ])
Catlab already does some simplifications on its syntax, like collecting up associative operations into lists. It is confusing how to integrate this with the e-graph structure, so I think the first step is to just not. That stuff isn't on by default, so you can get rid of it by defining your own versions using @syntax
using Catlab
using Catlab.Theories
import Catlab.Theories: compose, Ob, Hom
@syntax MyFreeCartesianCategory{ObExpr,HomExpr} CartesianCategory begin
    #otimes(A::Ob, B::Ob) = associate_unit(new(A,B), munit)
    #otimes(f::Hom, g::Hom) = associate(new(f,g))
    #compose(f::Hom, g::Hom) = associate(new(f,g; strict=true))
    #pair(f::Hom, g::Hom) = compose(mcopy(dom(f)), otimes(f,g))
    #proj1(A::Ob, B::Ob) = otimes(id(A), delete(B))
    #proj2(A::Ob, B::Ob) = otimes(delete(A), id(B))
end
# we can translate the axiom sets programmatically.
names(Main.MyFreeCartesianCategory)
#A = Ob(MyFreeCartesianCategory, :A)
A = Ob(MyFreeCartesianCategory, :A)
B = Ob(MyFreeCartesianCategory, :B)
f = Hom(:f, A,B)
g = Hom(:g, B, A)
h = compose(compose(f,g),f)
# dump(h)
dump(f)
#=
Main.MyFreeCartesianCategory.Hom{:generator}
  args: Array{Any}((3,))
    1: Symbol f
    2: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol A
      type_args: Array{GATExpr}((0,))
    3: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol B
      type_args: Array{GATExpr}((0,))
  type_args: Array{GATExpr}((2,))
    1: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol A
      type_args: Array{GATExpr}((0,))
    2: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol B
      type_args: Array{GATExpr}((0,))
=#
Rando thought: Can an EGraph data structure itself be expressed in Catlab? I note that IntDisjointSet does show up in a Catlab definition of pushouts. The difficulty and scariness of the EGraph is maintaining its invariants, which may be expressible as some kind of composition condition on the various maps that make up the EGraph.
Comments on Equation Search
The rewrite rule problem can be viewed as something like a graph. The vertices of the graph are every possible term. A rewrite rule that is applicable to a term describes a directed edge in the graph.
This graph is very large, often infinite, so it can only be described implicitly.
Proving the equivalence of two terms is like finding a path in this graph. You could attempt a greedy approach or a breadth first search approach, or any variety of your favorite search algorithm like A*.
This graph perspective fails to capture some useful properties though. Treating each term as an indivisible vertex fails to capture that there can be rewriting happening in only a small part of the term, with the vast majority of it left unchanged. There is a lot of opportunity for shared computations.
The EGraph from this perspective is a data structure for holding the already seen vertices.
I suspect some heuristics would be helpful, like applying the "simplifying" direction of the equations more often than the "complicating" direction. In a 5:1 ratio, let's say.
A natural algorithm to consider for optimizing terms is to take the best expression found so far, destroy the e-graph, and place just that expression in it.
Then try the rewrite rules. If the best expression found is still the old query, don't destroy the e-graph and instead apply a new round of rewrites to widen the radius of the e-graph. This gives you kind of a combo of greedy and complete search.
Partial Evaluation
Partial evaluation seems like a good technique for optimizing pattern matches. The ideal pattern match code expands out to a nicely nested set of if statements. One alternative to passing a dictionary might be to assign each pattern variable to an integer at compile time and instead pass an array, which would be a bit better. However, by using metaprogramming we can insert local variables into generated code and avoid the need for a dictionary being passed around at runtime. Then the Julia compiler can register allocate and whatnot like it usually does (and quite efficiently I'd expect).
See this post (in particular the reference links) https://www.philipzucker.com/metaocaml-style-partial-evaluation-in-coq/ for more on partial evaluation.
A first thing to consider is that we’re building up and destroying dictionaries at a shocking rate.
Secondly, dictionaries themselves are (relatively) expensive things.
I experimented a bit with using a curried form for match to see if maybe the Julia compiler was smart enough to sort of evaluate away my use of dictionaries, but that did not seem to be the case.
I found examining the @code_llvm and @code_native of the following simple experiments illuminating as to what Julia can and cannot get rid of when the dictionary is obviously known at compile time to human eyes. Mostly it needs to be very obvious. I suspect Julia does not have any special built-in compiler magic for reasoning about dictionaries, and trying to automatically infer what's safely optimizable by actually expanding the definition of a dictionary op sounds impossible.
function foo() # bad
    x = Dict(:a => 4)
    return x[:a]
end

function foo2()
    return 4
end

function foo3() # bad
    x = Dict(:a => 4)
    () -> x[:a]
end

function foo4() # much better. But we do pay a closure cost.
    x = Dict(:a => 4)
    r = x[:a]
    () -> r
end

function foo5()
    x = Dict(:a => 4)
    :( () -> $(x[:a]) )
end

function foo7() # note: redefined below
    return foo()
end

function foo6() # julia does however do constant propagation and deletion of unnecessary bullcrap
    x = 4
    y = 7
    y += x
    return x
end

function foo7() # this compiles well
    x = Base.ImmutableDict( :a => 4)
    return x[:a]
end

function foo8()
    x = Base.ImmutableDict( :a => 4)
    r = x[:a]
    z -> z == r # still has a closure indirection. It's pretty good though
end

function foo9()
    x = Base.ImmutableDict( :a => 4)
    r = x[:a]
    z -> z == 4
end

#@code_native foo7()
z = foo4()
#@code_llvm z()
z = eval(foo5())
@code_native z()
@code_native foo2()
@code_llvm foo6()
z = foo8()
@code_native z(3)
z = foo9()
@code_native z(3)
@code_native foo7()
# so it's possible that using an ImmutableDict is sufficient for julia to inline almost everything itself.
# You want the indexing to happen outside the closure.
A possibly useful technique is partial evaluation. We have in a sense built an interpreter for the pattern language. Specializing this interpreter gives a fast pattern matcher. We can explicitly build up the code Expr that will do the pattern matching by interpreting the pattern.
So here's a version that takes in a pattern and builds code that performs the pattern match in terms of if-then-else statements and runtime bindings.
Here's a sketch. This version only returns true or false rather than returning the bindings. I probably need to CPS it to get actual bindings out. I think that's what I was starting to do with the c parameter.
function prematch(p::PatVar, env, t, c)
    if haskey(env, p)
        :( $(env[p]) != $t && return false )
    else
        env[p] = gensym(p.id) # make a fresh variable
        :( $(env[p]) = $t )   # bind it at runtime to the current t
    end
end

function sequence(code)
    foldl(
        (s1, s2) -> quote
            $s1
            $s2
        end,
        code)
end

function prematch(p::PatTerm, env, t, c)
    if length(p.args) == 0
        :( ($(t).head != $(QuoteNode(p.head)) || length($(t).args) != 0) && return false )
    else
        quote
            if $(t).head != $(QuoteNode(p.head)) || $(length(p.args)) != length($(t).args)
                return false
            else
                $(sequence([prematch(a, env, :( $(t).args[$n] ), c) for (n, a) in enumerate(p.args)]))
            end
        end
    end
end
println( prematch( PatTerm(:f, []) , Dict(), :w , nothing) )
println( prematch( PatTerm(:f, [PatVar(:a)]) , Dict(), :t , nothing) )
println( prematch( PatTerm(:f, [PatVar(:a), PatVar(:a)]) , Dict(), :t , nothing) )