Tag Archives: datastructures

Partial Evaluation of a Pattern Matcher for E-graphs

By: Philip Zucker

Re-posted from: https://www.philipzucker.com/staging-patterns/

Partial Evaluation is cool.

There are different interpretations of what partial evaluation means. I think the term can be applied to anything that you’re, well, partially evaluating. This might include really reaching down into and ripping apart the AST of a code block, finding things like 1 + 7 and replacing them with 8.
However, my mental model of partial evaluation is a function that takes 2 arguments, one given at compile time and one at run time. Given the argument at compile time, you can hopefully make all sorts of decisions and control flow choices based on that information.
What I know of this, I’ve learned from reading Oleg Kiselyov http://okmij.org/ftp/meta-programming/index.html and his and others’ work in MetaOcaml. Sometimes this technique is called staged metaprogramming or generative metaprogramming, because you don’t really peek inside of Expr.
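
As a tiny illustration of that two-argument mental model, here is the classic power example, a sketch of my own in Julia (not code from the rest of the post): the exponent is known at compile time, the base only at runtime.

# Ordinary version: both arguments are runtime values.
pow(x, n) = n == 0 ? 1 : x * pow(x, n - 1)

# Staged version: n is known now, x is a piece of code available later.
# pow_staged(:x, 3) returns the expression :(x * (x * (x * 1))).
pow_staged(x::Union{Symbol,Expr}, n) = n == 0 ? 1 : :($x * $(pow_staged(x, n - 1)))

cube = eval(:(x -> $(pow_staged(:x, 3)))) # a compiled function equivalent to x -> x*x*x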

I wrote an implementation of pattern matching on e-graphs that is slow. https://www.philipzucker.com/egraph-2/ I think partial evaluation in this style is a nice way to speed it up.

A useful pattern for organizing programming tasks is to build tiny programming languages (DSLs) and then write a function that takes a data structure containing a program written in this tiny language and interprets it over inputs. It is very often the case that the actual program one is going to interpret is known at compile time.

A reasonable problem solving strategy is

  1. Try some maximally concrete examples and see what hand-written code you would write
  2. Write an interpreter over some DSL describing the class of programs
  3. Stage the interpreter with quasiquotation, rearranging minimally as necessary.
  4. Check that your examples produce the same code as your hand-written code, or code that is good enough.

So let’s start with 1.

Pattern matching feels a little scary, but when you actually sit down to do it on a concrete pattern, it ends up being a straightforward and boring pile of if-then-else clauses.

Here are some straightforward examples of regular pattern matching:

# pattern f(X)
function matchfx(e)
    if e isa Expr
        if e.head == :f && length(e.args) == 1
            return e.args[1]
        end
    end
end

# pattern f(X,Y)
function matchfxy(e)
    if e isa Expr
        if e.head == :f && length(e.args) == 2
            return [:X => e.args[1], :Y => e.args[2]]
        end
    end
end

EMatching throws a wrench into the simplicity in that you need to return multiple matches and search through the multiple enodes per eclass. Nevertheless, for concrete patterns, it isn’t really that complicated either.

#pattern f(X)
function matchfx(G, e::Int64)
    buf = []
    for enode in G[e]
        if enode.head == :f && length(enode.args) == 1
            push!(buf, enode.args[1])
        end
    end
    return buf
end

#pattern f(X,Y) returns 2 variable bindings
function matchfxy(G, e::Int64)
    buf = []
    for enode in G[e]
        if enode.head == :f && length(enode.args) == 2
            push!(buf, [enode.args[1], enode.args[2]])
        end
    end
    return buf
end

#pattern f(X,X) requires a check 
function match(G, e::Int64)
    buf = []
    for enode in G[e]
        if enode.head == :f && length(enode.args) == 2 && enode.args[1] == enode.args[2]
            push!(buf, enode.args[1])
        end
    end
    return buf
end

#f(g(X)) requires deeper traversal. # of loops is proportional to number of eclasses expanded
function matchfgx(G, e::Int64)
    buf = []
    for enode in G[e]
        if enode.head == :f && length(enode.args) == 1
            for enode2 in G[enode.args[1]]
                if enode2.head == :g && length(enode2.args) == 1
                    push!(buf, enode2.args[1])
                end
            end
        end
    end
    return buf
end

Ok, well how can we generate these?

An Interpreter

I just so happen to know from the egg https://egraphs-good.github.io/ paper and the de Moura and Bjørner paper http://leodemoura.github.io/files/ematching.pdf that it is smart to compile patterns into a sequence of instructions. This is also similar to what is done in Prolog with the Warren Abstract Machine. Sequences of instructions clarify the control flow compared to directly writing recursive functions to interpret patterns. This explicit control flow is subtly important for the ability to stage programs.

For what I’ve shown above we need 3 instructions

  1. Opening up an eclass. This makes a for loop
  2. Checking if two variables are the same like in f(X,X)
  3. Yielding a binding of variables

Here is a simple data type to describe patterns and a helper macro to construct them. It considers capitalized symbols to be variables.

using AutoHashEquals # provides @auto_hash_equals

@auto_hash_equals struct Pattern
    head::Symbol
    args
end

struct PatVar
    name::Symbol
end

pat(e::Symbol) = isuppercase(String(e)[1]) ? PatVar(e) : Pattern(e, [])
pat(e::Expr) = Pattern(e.args[1], pat.(e.args[2:end]) )
macro pat(e) 
    return pat(e)
end
# example usage: @pat f(g(X), X)
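
For example, with these definitions the pattern f(g(X), X) becomes a small tree of Pattern and PatVar nodes:

# Roughly what @pat f(g(X), X) builds (assuming the definitions above):
# Pattern(:f, [Pattern(:g, [PatVar(:X)]), PatVar(:X)])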

And here are the data types Bind, CheckClassEq, and Yield for the instructions of my language, corresponding to 1-3 above.

@auto_hash_equals struct ENodePat
    head::Symbol
    args::Vector{Symbol}
end

@auto_hash_equals struct Bind 
    eclass::Symbol
    enodepat::ENodePat
end

@auto_hash_equals struct CheckClassEq
    eclass1::Symbol
    eclass2::Symbol
end

@auto_hash_equals struct Yield
    yields::Vector{Symbol}
end

Here is a function that takes a Pattern and turns it into a sequence of instructions by traversing it depth first. This may not produce the most optimal ordering. During the pass I maintain a context in which you can look up variables from the Pattern and find the new names I bound for them in the instructions.

function compile_pat(eclass, p::Pattern,ctx)
    a = [gensym() for i in 1:length(p.args) ]
    return vcat(  Bind(eclass, ENodePat(p.head, a  )) , [compile_pat(eclass, p2, ctx) for (eclass , p2) in zip(a, p.args)]...)
end

function compile_pat(eclass, p::PatVar,ctx)
    if haskey(ctx, p.name)
        return CheckClassEq(eclass, ctx[p.name])
    else
        ctx[p.name] = eclass
        return []
    end
end
    
function compile_pat(p)
    ctx = Dict()
    insns = compile_pat(:start, p, ctx)
    
    return vcat(insns, Yield( [values(ctx)...]) ), ctx
end
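
To make that concrete, here is roughly what compiling the pattern f(g(X), X) produces, with made-up names standing in for the gensyms:

# compile_pat(@pat f(g(X), X)) gives, up to gensym names, something like:
# [ Bind(:start, ENodePat(:f, [:s1, :s2])),   # loop over enodes f(s1, s2) in the root eclass
#   Bind(:s1,    ENodePat(:g, [:s3])),        # loop over enodes g(s3) in eclass s1
#   CheckClassEq(:s2, :s3),                   # both occurrences of X must be the same eclass
#   Yield([:s3]) ]                            # report the binding of X
# with ctx == Dict(:X => :s3)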

Then I can write an interpreter over these instructions. You case on the instruction type. If it is a Bind you make a for loop over all ENodes in that eclass, check if the head matches the pattern’s head and the length is right, add all the newly found eclass ids to the context, and finally recurse into the next instruction. If it is a CheckClassEq you check with the context to see if we should stop this branch of the search. If it is a Yield we add the current solution to the buf. The buf thing came up in Alessandro’s journey to optimize the ematcher in Metatheory.jl https://github.com/0x0f0f0f/Metatheory.jl and it really cleared things up for me compared to my monstrous Channel-based implementation.

function interp_unstaged(G, insns, ctx, buf)
    if length(insns) == 0
        return
    end
    insn = insns[1]
    insns = insns[2:end]
    if insn isa Bind
        for enode in G[ctx[insn.eclass]]
            if enode.head == insn.enodepat.head && length(enode.args) == length(insn.enodepat.args)
                for (n,v) in enumerate(insn.enodepat.args)
                    ctx[v] = enode.args[n]
                end
                interp_unstaged(G, insns, ctx, buf)
            end
        end
    elseif insn isa Yield
        push!(buf, [ctx[y] for y in insn.yields])
    elseif insn isa CheckClassEq
        if ctx[insn.eclass1] == ctx[insn.eclass2]
            interp_unstaged(G, insns, ctx, buf)
        end
    end
end

function interp_unstaged(G, insns, eclass0)
        buf = []
        interp_unstaged(G, insns, Dict{Symbol,Any}([:start => eclass0]) , buf)
        return buf
end

Ok, that is kind of complicated but also kind of not. It may be more complicated to explain and read than it is to write.

Let’s stage it!

Staging the Interpreter

We know the Pattern at compile time, but we want to run matches many, many times quickly at runtime. It makes sense to partially evaluate things. Let’s look at a related example first.

An aside: Partially Evaluating Many For Loops

Consider a brute force search over 3 bits. You could write it like this

for i in (false,true)
    for j in (false, true)
        for k in (false,true)
            println([i,j,k])
        end
    end
end

Now suppose I want code that works for an arbitrary number n of bits. A bit more complicated. One way of doing it is via recursion. If you know how to solve the n-1 bits problem, all you need to do is run that problem in a single extra loop.

function allbits(n, partial)
   if n == 0
        println(partial)
        return 
   end
   for i in (false,true)
       partial[n] = i
       allbits(n-1, partial)
   end
end
function allbits(n) 
    allbits(n, fill(false, n))
end

By adding quasiquotation annotations to this code, we can actually produce the same code as above. We assume n is known at compile time. We also know the size of the partial vector, but the values inside the partial vector are no longer boolean values; instead they represent the code of boolean values available at runtime.
It is unfortunate that I need to gensym the i variable. Perhaps there are facilities in Julia to avoid this, perhaps not. Note however that the code is otherwise structurally identical to the above.

function allbits(n, partial::Vector{Symbol})
   if n == 0
        return :(println([$(partial...)]))  
   end
    i = gensym()
    quote
       for $i in (false,true)
           $(begin
                partial[n] = i
                allbits(n-1, partial)
            end)
       end
   end
end
function allbits(n) 
    allbits(n, fill(:foo, n))
end

prettify(allbits(3)) # prettify comes from MacroTools
#=
:(for coati = (false, true)
      for caterpillar = (false, true)
          for seahorse = (false, true)
              println([seahorse, caterpillar, coati])
          end
      end
  end)
=#
# To use
eval(allbits(3))

Staged Interpreter

I just so happen to have written the unstaged interpreter in such a way that very minimal changes are necessary to build the staged version.

Now the ctx is a compile time mapping of Insn variables to the expressions that at runtime will hold the appropriate values. G and buf are the runtime values that hold the egraph and vector of return variable bindings.
The way I did this was to go through and think about what is available at compile time, keeping it out of quotes, and what is only available at runtime, putting it in quotes. It is a remarkably mechanical process to write this stuff once you get the hang of it. I suggest reading Oleg Kiselyov’s tutorials and mini-book for more.

Code = Union{Expr,Symbol}

function interp_staged(G::Code, insns, ctx, buf::Code) # G could be an argument or a global variable.
    if length(insns) == 0
        return 
    end
    insn = insns[1]
    insns = insns[2:end]
    if insn isa Bind
        enode = gensym(:enode)
        quote
            for $enode in $G[$(ctx[insn.eclass])] # we can store an integer of the eclass.
                if $enode.head == $(QuoteNode(insn.enodepat.head)) && length($enode.args) == $(length(insn.enodepat.args))
                    $(
                        begin
                        for (n,v) in enumerate(insn.enodepat.args)
                            ctx[v] = :($enode.args[$n])
                        end
                        interp_staged(G,  insns, ctx, buf)
                        end
                    )
                end
            end
        end
    elseif insn isa Yield
        quote
            push!( $buf, [ $( [ctx[y] for y in insn.yields]...) ])
        end
    elseif insn isa CheckClassEq
        quote
            if $(ctx[insn.eclass1]) == $(ctx[insn.eclass2])
                $(interp_staged(G, insns, ctx, buf))
            end
        end
    end
end

function interp_staged(insns)
    quote
        (G, eclass0) -> begin
            buf = []
            $(interp_staged(:G, insns, Dict{Symbol,Any}([:start => :eclass0]) , :buf))
            return buf
        end
    end
end

using MacroTools # for prettify
p1, ctx1 = compile_pat(@pat f(X))
println(prettify(interp_staged( p1  )))
#=
(G, eclass0)->begin
        buf = []
        for sardine = G[eclass0]
            if sardine.head == :f && length(sardine.args) == 1
                push!(buf, [sardine.args[1]])
            end
        end
        return buf
    end
=#
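
A pattern with a repeated variable also shows the CheckClassEq instruction being staged away into a plain if. Roughly (the loop variable name below is a gensym that I have prettified by hand):

p2, ctx2 = compile_pat(@pat f(X, X))
println(prettify(interp_staged( p2 )))
#=
(G, eclass0)->begin
        buf = []
        for otter = G[eclass0]
            if otter.head == :f && length(otter.args) == 2
                if otter.args[2] == otter.args[1]
                    push!(buf, [otter.args[1]])
                end
            end
        end
        return buf
    end
=#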

Bits and Bobbles

Example tests

@testset "Basic Compile" begin
    @test compile_pat(@pat f)[1] == [Bind(:start, ENodePat(:f, [])), Yield([])]

    p1, ctx1 = compile_pat(@pat f(X))
    @test p1 == [Bind(:start, ENodePat(:f,  [ctx1[:X]]  )), Yield([ctx1[:X]])]
    
    p1, ctx = compile_pat(@pat f(X, Y))
    @test p1 == [Bind(:start, ENodePat(:f,  [ctx[:X], ctx[:Y] ]  )), 
                 Yield([ctx[:X], ctx[:Y]])]
    @test ctx[:X] != ctx[:Y]
    @test length(ctx) == 2

    p1, ctx = compile_pat(@pat f(X, X))
    x2 = p1[1].enodepat.args[2]
    @test length(ctx) == 1
    @test p1 == [Bind(:start, ENodePat(:f,  [ctx[:X], x2 ]  )), 
                 CheckClassEq( x2 , ctx[:X] ),
                 Yield([ctx[:X]])]

    p1, ctx = compile_pat(@pat f(g(X)))
    g = p1[1].enodepat.args[1]
    @test length(ctx) == 1
    @test p1 == [Bind(:start, ENodePat(:f,  [g ]  )),
                Bind(g, ENodePat(:g,  [ctx[:X] ]  )) ,
                Yield([ctx[:X]])]

end

using MacroTools
@testset "Basic Match" begin
    G = EGraph()
    a = addexpr!(G, :a) 
    b = addexpr!(G, :b) 
    fa = addexpr!(G, :( f(a)  ))
    fb = addexpr!(G, :( f(b)  ))

    p1, ctx1 = compile_pat(@pat f(X))

    println(prettify(interp_staged( p1  )))
    match = eval(interp_staged( p1  ))
    @test match(G, fa) == [ [ a ] ]
    @test match(G, fb) == [ [ b ] ]
    union!(G, fa, fb)
    #println(G)
    @test match(G, fa) == [ [ b ], [ a ]  ]
    union!(G, a, b)
    #println(G)
    #println(congruences(G))
    propagate_congruence(G)
    #println(G)
    #println(congruences(G))
    rebuild!(G) # There is something kind of sick here.
    #println(G)
    a = addexpr!(G, :a)
    @test match(G, fa) == [ [ a ] ]


    p1, ctx = compile_pat(@pat f(g(X)))
    match = eval(interp_staged( p1  ))
    G = EGraph()
    a = addexpr!(G, :a) 
    b = addexpr!(G, :b) 
    fga = addexpr!(G, :( f(g(a))  ))
    fb = addexpr!(G, :( f(b)  ))
    @test match(G,fb) == []
    @test match(G,fga) == [[a]]
    union!(G, fb, fga)
    @test match(G,fb) == [[a]]
    @test match(G,fga) == [[a]]



    p1, ctx = compile_pat(@pat f(X, X))
    #println(prettify(interp_staged( p1  )))
    match = eval(interp_staged( p1  ))
    G = EGraph()
    a = addexpr!(G, :a) 
    b = addexpr!(G, :b) 
    fga = addexpr!(G, :( f(g(a))  ))
    fb = addexpr!(G, :( f(b)  ))
    fab = addexpr!(G, :( f(a,b)  ))
    @test match(G,fab) == []
    faa = addexpr!(G, :( f(a,a)  ))
    @test match(G,faa) == [[a]]
    @test match(G,fab) == []
    union!(G,a,b)
    rebuild!(G)
    #propagate_congruence(G)
    @test match(G,faa) == [[b]]
    @test match(G,fab) == [[b]]



    #println(prettify(interp_staged( p1  )))
end

A Simplified E-graph Implementation

By: Philip Zucker

Re-posted from: https://www.philipzucker.com/a-simplified-egraph/

I’ve been spending some time mulling over e-graphs. I think I have it kind of pared down until it’s fairly simple.
This implementation is probably not high performance, as the main trick is removing a genuinely useful indexing structure. Still, this implementation is small enough that I can kind of keep it in my head. It has become rather cute.

For a user ready implementation of egraphs, see Metatheory https://github.com/0x0f0f0f/Metatheory.jl or egg https://egraphs-good.github.io/

In a computer, terms like (a + b) * c are typically stored as trees. The EGraph converts this tree structure to one with an extra layer of indirection. Instead of nodes directly connecting to nodes, they connect first through a different kind of node labelled by an integer. These integer nodes are called eclasses and the modified original nodes are called enodes.

The EGraph is a bipartite directed graph between eclasses and enodes. ENodes belong uniquely to eclasses. Once we start merging eclasses, eclasses will point to more than one enode. The graph may also develop loops allowing representation in a sense of infinite terms.

Last time https://www.philipzucker.com/union-find-dict/, I built the abstraction of a union-find dict. This data structure allows you to retrieve information keyed on an equivalence class while still supporting the union operation. Given this piece, it is simple to define the two data structures.

using AutoHashEquals # provides @auto_hash_equals
# IntDisjointMap is the union-find dict from the previous post.

@auto_hash_equals struct ENode
    head::Symbol
    args::Vector{Int64}
end

struct EGraph
    eclass::IntDisjointMap
    memo::Dict{ENode, Int64}
end

EGraph() = EGraph( IntDisjointMap(union)  , Dict{ENode,Int64}())

The eclass field is a union-find dictionary from equivalence classes to vectors of ENodes. We tell the underlying IntDisjointMap that upon a union! of equivalence classes, we will union the enode vectors in the codomain of the IntDisjointMap to each other.

The memo table is not strictly necessary, but it gives us a good way to lookup which eclass an enode belongs to. Otherwise we’d have to brute force search the entire IntDisjointMap to find ENodes when we want them.

ENodes hold references to EClass ids, which unfortunately can go stale. We can canonize ENodes to use the freshest equivalence class indices.

canonize(G::EGraph, f::ENode) = ENode(f.head, [find_root(G.eclass, a) for a in f.args])

To add an ENode to the egraph first we canonize it, then we check if it already is in the memo table, and if not we actually push it in the IntDisjointMap and update the memo table.

function addenode!(G::EGraph, f::ENode)
    f = canonize(G,f)
    if haskey(G.memo, f)
        return G.memo[f]
    else
        id = push!(G.eclass, [f])
        G.memo[f] = id
        return id
    end
end

#convenience functions for pushing Julia Expr
addexpr!(G::EGraph, f::Symbol) = addenode!(G, ENode(f,[]))
function addexpr!(G::EGraph, f::Expr)
    @assert f.head == :call
    addenode!(G,  ENode(f.args[1], [ addexpr!(G,a) for a in f.args[2:end] ]  ))
end
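
A small usage sketch (my own, assuming the definitions above): adding the same expression twice hits the memo table and returns the same eclass id.

G = EGraph()
a  = addexpr!(G, :a)        # eclass id of the constant a
fa = addexpr!(G, :(f(a)))   # eclass id of f(a)
@assert addexpr!(G, :(f(a))) == fa   # second add finds it in memo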

When we assert an equality to an egraph, we take the union! of the two corresponding eclasses. We union! the underlying IntDisjointMap, then we recanonize all the held ENodes in that eclass and update the memo table.

function Base.union!(G::EGraph, f::Int64, g::Int64)
    id = union!(G.eclass, f, g)
    eclass = ENode[]
    for enode in G.eclass[id] 
        delete!(G.memo, enode) # should we even bother with delete?
        enode = canonize(G, enode) # should canonize return whether it did anything or not?
        G.memo[enode] = id
        push!(eclass, enode)
    end
    G.eclass[id] = eclass
end

That’s kind of it.

The big thing we haven’t discussed is calculating congruence closure. In my original presentation, this was a whole ordeal and the reason why we needed to maintain parent pointers from eclasses to enodes. This was very confusing.

Instead we can just find congruences via a brute force sweep over the egraph. This is inefficient compared to having likely candidates pointed out to us by the parent pointers. However, during naive ematching we are sweeping over the egraph anyway to find possible rewrite rule applications. This approach makes congruence closure feel rather similar to the other rewrite rules. There may be some utility in not considering congruence closure as a truly intrinsic part of the egraph. Perhaps you could use it for systems where congruence does not strictly hold?

An unfortunate thing is that the congruence sweep needs to be applied, in the worst case, a number of times proportional to the depth of the egraph, as it only propagates congruences one layer at a time.

How it works: after a union! operation there are non canonical ENodes held in both memo and eclass. These noncanonical ENodes are exactly those whose arguments include the eclass that was just turned into a child of another eclass. These are also exactly the ENodes that are candidates for congruence closure propagation. We can detect them during the sweep by canonization.

This expensive congruence sweep forgives more sins than the more efficient one. Something that can happen is that we try to addexpr! an ENode that is one of the stale ones, in other words it should be in the memo table but is not. This will falsely create a new eclass for this ENode. However, the congruence closure sweep will find this equivalence on the next pass.


# I forgot to include this IntDisjointMap iterator function in my last post.
# Conceptually it belongs there.
function Base.iterate(x::IntDisjointMap, state=1)
    while state <= length(x.parents)
        if x.parents[state] < 0
            return ((state, x.values[state]) , state + 1)
        end
        state += 1
    end
    return nothing
end

# returns a list of tuples of found congruences
function congruences(G::EGraph)
    buf = Tuple{Int64,Int64}[] 
    for (id1, eclass) in G.eclass #alternatively iterate over memo
        for enode in eclass
            cnode = canonize(G,enode)
            if haskey(G.memo, cnode)
                id2 = G.memo[cnode]
                if id1 != id2
                    push!(buf, (id1,id2))
                end
            end
        end
    end
    return buf
end

# propagate all congruences
function propagate_congruence(G::EGraph)
    cong = congruences(G)
    while length(cong) > 0
        for (i,j) in cong
            union!(G,i,j)
        end
        cong = congruences(G)
    end
end
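
As a quick sanity check (my own sketch, assuming the definitions above), asserting a = b and sweeping makes f(a) and f(b) congruent:

G = EGraph()
a  = addexpr!(G, :a)
b  = addexpr!(G, :b)
fa = addexpr!(G, :(f(a)))
fb = addexpr!(G, :(f(b)))
union!(G, a, b)                 # assert a = b
propagate_congruence(G)         # sweep until no new congruences are found
@assert find_root(G.eclass, fa) == find_root(G.eclass, fb)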

Bits and Bobbles

In principle I think this formulation makes it easier to parallelize congruence finding alongside rewrite rule matching. The rewriting process becomes a swapsies between finding tuples to union and actually applying them.

Everything in this post could probably be tuned up to be more efficient.

To add analyses, you want to store a compound structure in the IntDisjointMap, something like Tuple{Vector{ENode}, Analysis}. The merge operation then does both the enode merge and the analysis merge.
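
A rough sketch of what that could look like (hypothetical names, not code from the post): pair the enode vector with an analysis value and merge both on union.

struct EClassData
    nodes::Vector{ENode}
    analysis            # e.g. a known constant, a size bound, ...
end

# joinanalysis is a stand-in for whatever lattice join the analysis needs.
mergeclass(a::EClassData, b::EClassData) =
    EClassData(union(a.nodes, b.nodes), joinanalysis(a.analysis, b.analysis))

# EGraph() = EGraph(IntDisjointMap(mergeclass), Dict{ENode,Int64}())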

Possibly changing enodes to be binary might be nice. One can compile arbitrary arity into this. Then everything fits in place in the appropriate arrays, removing significant indirection.

Some tests

using EGraphs
using Test

@testset "Basic EGraph" begin
G = EGraph()
a = addenode!(G, ENode(:a, []))
b = addenode!(G, ENode(:b, []))
#println(G)
union!(G, a, b)
#println(G)
@test addenode!(G, ENode(:a, [])) == 2
@test addenode!(G, ENode(:c, [])) == 3
@test addenode!(G, ENode(:f, [a])) == 4
union!(G, 3, 4)

#= println(G)
for (k,v) in G.eclass
    println(k,v)
end =#
G = EGraph()
a = addenode!(G, ENode(:a, []))
b = addenode!(G, ENode(:b, []))
c = addenode!(G, ENode(:c, []))
union!(G, a, b)
fa = addenode!(G, ENode(:f, [a])) 
fb = addenode!(G, ENode(:f, [b])) 
fc = addenode!(G, ENode(:f, [c])) 

ffa = addenode!(G, ENode(:f, [fa])) 
ffb = addenode!(G, ENode(:f, [fb])) 

@test congruences(G) == [(fa,fb)]


for (x,y) in congruences(G)
    union!(G,x,y)
end

@test congruences(G) == [(ffa,ffb)]

union!(G, a, c)

@test congruences(G) == [(fc,fb), (ffa,ffb)]

for (x,y) in congruences(G)
    union!(G,x,y)
end

@test congruences(G) == []


G = EGraph()
f5a = addexpr!(G, :( f(f(f(f(f(a)))))  ))
f2a = addexpr!(G, :( f(f(a))  ))
@test length(G.eclass) == 6
union!(G , f5a, f2a)
@test find_root(G,f5a) == find_root(G,f2a)
@test length(G.eclass) == 5
f5a = addexpr!(G, :( f(f(f(f(f(a)))))  ))
f2a = addexpr!(G, :( f(f(f(a)))  ))
@test length(G.eclass) == 5

G = EGraph()
f5a = addexpr!(G, :( f(f(f(f(f(a)))))  ))
fa = addexpr!(G, :( f(a)  ))
a = addexpr!(G, :a)
@test length(G.eclass) == 6
union!(G , fa, a)
@test find_root(G,fa) == find_root(G,a)

propagate_congruence(G)
@test length(G.eclass) == 1

G = EGraph()
ffa = addexpr!(G, :( f(f(a))  ))
f5a = addexpr!(G, :( f(f(f(f(f(a)))))  ))
a = addexpr!(G, :a)
@test length(G.eclass) == 6
union!(G , ffa, a)
@test find_root(G,ffa) == find_root(G,a)

propagate_congruence(G)
@test length(G.eclass) == 2



end