Author Archives: Philip Zucker

JuliaCon 2021 Talk on Metatheory.jl and Snippets From the Cutting Room Floor

By: Philip Zucker

Re-posted from: https://www.philipzucker.com/juliacon-talk/

Alessandro very graciously invited me to join him in giving a talk on Metatheory.jl, his package for rewriting and egraphs. Thanks Alessandro!

I keep saying to myself I should try to give more talks, but hoo boy is it stressful. I felt like the whole month of June I was freaking out. And yet, it came off just fine I think. I was worried for nothing. Whenever these things come up, I always think that this time I need to tie up all the loose ends that I don’t know the answer to. But these loose ends exist because they are hard or they need stewing time.

It’s lunacy and hubris to think I can fix it all, but I do it every time. We had way way too much material anyway!

Here’s a snippet of todos from my notes.

Metatheory talk:

  • PEGs
  • Rewriting functional programs – https://www.haskellforall.com/2013/12/equational-reasoning.html
  • Gries equational logic?
  • datalog?
  • Summation – WIP though.
  • Automating catlab translation – This is only half built
  • The ideas behind the catlab translation
  • Rearranging linear algebra expressions
  • translate Egg Projects
  • rewriting aexpr + bexpr

Ridiculous. Here’s some of the bits and pieces I did make though.

Stream Fusion

I started writing down how to optimize iterator expressions, which is known as stream fusion. It is a very cool application of egraphs, especially if you extract for asymptotic complexity rather than AST size.

using Pkg
Pkg.activate("metatheory")
Pkg.add(url="https://github.com/0x0f0f0f/Metatheory.jl.git")
using Metatheory
using Metatheory.EGraphs
@metatheory_init ()

array_theory = @theory begin
    reverse(reverse(x)) => x
    map(f,reverse(x)) => reverse(map(f, x))
    map(f,fill(x,N)) => fill(apply(f,x), N) # hmm
    filter(f,reverse(x)) => reverse(filter(f,x))
    reverse(fill(x,N)) => fill(x,N) 
    #map(f,x)[n:m] = map(f,x[n:m]) # but does NOT commute with filter
    filter(f, fill(x,N)) => f(x) ? fill(x,N) : fill(x,0)
    filter(f, filter(g, x)) => filter(f && g, x) # using functional &&
    cat(fill(x,N),fill(x,M)) => fill(x,N + M)
    cat(map(f,x), map(f,y)) => map(f, cat(x,y))
    map(f, cat(x,y)) => cat(map(f,x), map(f,y)) 
    map(f,map(g,x)) => map(f ∘ g, x)
    reverse( cat(x,y) ) => cat(reverse(y), reverse(x))
    sum(fill(x,N)) => x * N
    map(f,x)[n] => apply(f,x[n]) #?
    apply(f ∘ g, x) => apply(f, apply(g, x))
    fill(x,N)[n] => x
    cumsum(fill(x,N)) => collect(x:x:(N*x))
    length(fill(x,N)) => N
end

#G = EGraph(:(a * b * c * d * e));
#G = EGraph(:( reverse(reverse(fill(10,20))) )) 
#G = EGraph(:( map(x -> 7 * x, reverse(cat(fill(13,40),fill(10,20))) )))
G = EGraph(:( map(x -> 7 * x, fill(3,4) )))
G = EGraph(:( map(x -> 7 * x, fill(3,4) )[1]))



report = saturate!(G,array_theory);
ex = extract!(G, astsize)
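
Rules like these are easy to get subtly wrong, so it can help to spot-check them on concrete arrays before handing them to an egraph. Here is a minimal sketch using only Base Julia. It assumes that cat in the theory means one-dimensional concatenation (written vcat here), and the sample functions f, g, p, q and the arrays are made up for illustration.

```julia
# Spot-check a few of the rewrite rules on concrete data.
# Assumption: `cat` in the theory is 1-D concatenation, written `vcat` here.
f(x) = 7 * x
g(x) = x + 1
p(x) = iseven(x)
q(x) = x > 2
xs = [1, 2, 3, 4, 5]
ys = [10, 20]

checks = [
    reverse(reverse(xs)) == xs,
    map(f, reverse(xs)) == reverse(map(f, xs)),
    map(f, fill(3, 4)) == fill(f(3), 4),
    filter(p, filter(q, xs)) == filter(x -> p(x) && q(x), xs),
    vcat(fill(3, 2), fill(3, 4)) == fill(3, 2 + 4),
    map(f, vcat(xs, ys)) == vcat(map(f, xs), map(f, ys)),
    map(f, map(g, xs)) == map(f ∘ g, xs),
    reverse(vcat(xs, ys)) == vcat(reverse(ys), reverse(xs)),
    sum(fill(3, 4)) == 3 * 4,
    cumsum(fill(3, 4)) == collect(3:3:(4 * 3)),
    length(fill(3, 4)) == 4,
]
all(checks)
```

A check like this only tests one sample per rule, so it catches typos rather than proving anything.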

Links on stream fusion:

E-PEGs

E-PEGs https://www.cs.cornell.edu/~ross/publications/eqsat/ are a compiler IR format that plays with the egraph. It’s interestingly related to the above. The basic idea, I think, is to enhance the CFG so that it is effectively in a purely functional form. Gated phi nodes really are basically functional if-then-else constructs, since they reify the path condition. The other kinds of funky nodes they use for loops can be thought of as functional stream combinators. In some sense, the values a variable takes on inside a loop are a stream of values.
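
To make the "gated phi is just if-then-else" reading concrete, here is a small sketch in plain Julia. The names ϕ, imperative, functional and loop_values are made up for illustration and are not part of the E-PEG work or of Metatheory.jl. The stream at the end uses Iterators.countfrom and Iterators.takewhile from Base.

```julia
# Hypothetical helper: a gated phi node read as a functional if-then-else.
# Note that, as an ordinary function, it evaluates both branches eagerly.
ϕ(cond, t, f) = cond ? t : f

# Imperative version: x is assigned on both branches of the if.
function imperative(a)
    if a > 0
        x = a + 1
    else
        x = a - 1
    end
    return 2 * x
end

# Functional version: the path condition is reified as the first argument of ϕ.
functional(a) = 2 * ϕ(a > 0, a + 1, a - 1)

# The values a loop counter takes on, viewed as a stream.
# A loop `i = 0; while i < n; ...; i += 1; end` visits exactly these values of i.
loop_values(n) = collect(Iterators.takewhile(<(n), Iterators.countfrom(0)))
```

Both versions of the function agree on every input, which is the sense in which the phi node carries the same information as the branch.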

I started translating the table from chapters 7/8 of Tate’s thesis into Metatheory.jl and Julia, but I didn’t finish. I now appreciate the cleverness of the @capture macro from MacroTools though.

using MacroTools

function translate_prog(p)
    @capture(p, function f_(args__) body_ end)
    translate_statement(body, Dict(zip(args,args)), 0)[:retvar]
end


function translate_statement(s, env, loop_level) # translate statement
    #println(s)
    if @capture(s, begin s1_; s2__ end) && length(s2) > 0
        return translate_statement( Expr(:block, s2...), translate_statement(s1, env, loop_level), loop_level)
    elseif @capture(s, x_ = e_)
        env2 = copy(env)
        env2[x] = translate_expr(e, env)
        return env2
    elseif @capture(s, if e_ s1_ else s2_ end)
        return phi( translate_expr(e, env), translate_statement(s1, env, loop_level) , translate_statement(s2, env, loop_level) )
    elseif @capture(s, while c_ b_ end)
        # TODO: loops are not translated yet
    end
    println("no match: ", s)
end

function translate_expr(e, env)
    if @capture( e , op_(args__))
        args = [translate_expr(arg, env) for arg in args]
        :($op( $(args...) ))
    elseif e isa Number
        e
    else
        env[e]
    end
end

function combine(m1,m2,conflict)
    res = Dict()
    for (k,v) in m1
        res[k] = v
    end
    for (k,v) in m2
        if haskey(res,k)
            res[k] = conflict( m1[k], m2[k] )
        else
           res[k] = v     
        end
    end
    return res
end

function phi(n, env1, env2,)
    combine(env1,env2, (t,f) -> :(ϕ($n, $t, $f )  ))
end
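
For reference, the merge performed by combine and phi can also be written with mergewith from Base (Julia 1.5 or later). Here is a sketch on two hand-written environments. The name phi_env is made up here, and the environments are illustrative rather than actual output of the translator above.

```julia
# Merge two branch environments; a variable present in both gets a ϕ node.
phi_env(n, env1, env2) = mergewith((t, f) -> :(ϕ($n, $t, $f)), env1, env2)

cond = :(a > 0)
then_env = Dict{Symbol,Any}(:a => :a, :x => :(a + 1))
else_env = Dict{Symbol,Any}(:a => :a, :x => :(a - 1), :y => 0)

merged = phi_env(cond, then_env, else_env)
merged[:x]  # the ϕ node selecting between the two definitions of x
```

Like combine, this wraps every shared key in a ϕ, even when both branches agree, so the unchanged variable a becomes ϕ(a > 0, a, a). A rewrite rule ϕ(c, x, x) => x would clean that up.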

Automating The Catlab Translation

https://www.philipzucker.com/metatheory-progress/
This is only half built. It sounded like Alessandro was on the case a while back so I got discouraged. I really do not remember what state this is in, but here it is for the curious and extremely bold.

using Pkg
Pkg.activate("metacat")
Pkg.add(url="git@github.com:AlgebraicJulia/Catlab.jl.git")
Pkg.add(url="https://github.com/0x0f0f0f/Metatheory.jl.git")
using Metatheory
using Metatheory.EGraphs

using Catlab
using Catlab.Theories

function convert_types(types)
    ret = []
    for typ in types
        for accessor in typ.params
            lhs = Expr(:call, accessor, Expr(:call , :type, Expr(:call,  typ.name, typ.params... )))
            push!(ret, :( $lhs => $accessor ))
        end
    end
    return ret
end


symbols(e::Expr) = length(e.args) > 1 ? reduce(union, symbols.(e.args[2:end]) ) : Set([])
symbols(s::Symbol) = Set([s])

find_field(typ::Expr, a) = findfirst(x -> x == a, typ.args[2:end])
find_field(typ::Symbol, a) = nothing
function find_context(context, a)
    for (f,typ) in context
        i = find_field(typ,a)
        if i !== nothing
            return (f,typ,i)
        end
    end
end

find_field(:(Hom(A,B)), :C)
find_context(Dict(:f => :(Hom(A,B))), :A)

function build_type_finder(ctx, u) # builds functional form of extractor from context.
    ident, typ, pos = find_context(ctx, u) 
    u => :(proj($pos, type($ident )))
end

type_finder_dict(ctx, unknowns) = Dict([ build_type_finder(ctx, u) for u in unknowns])

replace(e::Symbol, ud) = begin
    #println(e)
    haskey(ud,e) ?  ( ud[e])  : e
end
replace(e::Expr, unknown_replace_dict) = begin
    Expr(e.head, e.args[1], [ replace(a, unknown_replace_dict) for a in e.args[2:end] ]... ) end

function type_terms(terms)
    ret = []
    for term in terms
        lhs = Expr(:call , :type, Expr(:call,  term.name, term.params... ))
        
        
        known = Set(term.params)
        unknowns = setdiff(symbols(term.typ) , known)

        
        unknown_replace_dict = type_finder_dict(term.context, unknowns)
        builtterm = replace(term.typ, unknown_replace_dict)
        #println(unknown_replace_dict)
        #println(builtterm)
        #=
        typ_map = Dict()
        
        while !isempty(unknowns)
            u = pop!(unknowns)
            push!(known, u)
            type_u = term.context[u]
            typ_map[u] = type_u
            unknowns = unknowns ∪ (setdiff(symbols(type_u) , known ))
        end
        println(typ_map)
        =#
        push!(ret, :( $lhs => $(builtterm)))
    end
    return ret
end
type_terms( theory(Category).terms )

# If we want to go with the accessor encoding, We need to lookup parameters on the right hand side


# I should possibly use Metatheory to do replacement


# should I make these function singular and just map them?
function convert_axioms(axioms)
    ret = []
    for axiom in axioms
        #lhs = Expr(:call , :type, Expr(:call,  term.name, term.params... ))
        #println(axiom)

        leftsyms = symbols(axiom.left)
        rightsyms = symbols(axiom.right)
        # left to right
        unknowns = setdiff(rightsyms, leftsyms)
        #println(unknowns)
        #println([ find_context(axiom.context, u) for u in unknowns])
        d = type_finder_dict(axiom.context, unknowns)
        newright = replace(axiom.right, d)
        push!(ret, :( $(axiom.left) => $(newright)))
        
        
        unknowns = setdiff(leftsyms, rightsyms)
        d = type_finder_dict(axiom.context, unknowns)
        newleft = replace(axiom.left, d)
        push!(ret, :( $(axiom.right) => $(newleft)))
    end
    return ret
end
#convert_axioms( theory(Category).axioms)
convert_axioms( theory(MonoidalCategory).axioms)
#theory(MonoidalCategory).axioms


function find_term(termcons, n)
    for termcon in termcons
        if termcon.name == n
            return termcon
        end
    end
end

function typing_equations(theory, s::Symbol)
    return []
end
function typing_equations(theory, e::Expr)
    @assert e.head == :call
    name = e.args[1]
    term_con = find_term(theory.terms, name)
    println(e)
    #freshparams = [p => gensym(p)  for p in term_con.params ]    
    #rec_equations = [ kv[2] => a  for (kv,a)  in   zip(freshparams, e.args[2:end]) ]
    rec_equations = [k => v for (k,v) in zip(term_con.params, e.args[2:end])]
    r2 = [  replace( k, Dict(rec_equations)) => replace( v, Dict(rec_equations))  for (k,v) in term_con.context]
    r3 = [ typing_equations(theory, a) for a in e.args[2:end] ]
    return vcat(r2, r3...) # flatten the equations coming from the subterms
    
end
typing_equations( theory(MonoidalCategory), :(otimes(f,id(a))) )
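
The core trick in convert_axioms is that a variable appearing on only one side of an axiom can be recovered by projecting it out of the type of a variable that is known. Here is a sketch of that idea that does not depend on Catlab. The context is written by hand, the helper names syms, find_in_context and substitute are made up for this example, and proj and type are just the uninterpreted symbols used in the listing above.

```julia
# Collect the symbols appearing as arguments of a call expression.
syms(e::Expr) = length(e.args) > 1 ? reduce(union, syms.(e.args[2:end])) : Set{Symbol}()
syms(s::Symbol) = Set([s])
syms(x) = Set{Symbol}()  # literals contribute nothing

# Find a context entry whose type mentions `a`, and the argument position of `a`.
function find_in_context(ctx, a)
    for (name, typ) in ctx
        typ isa Expr || continue
        i = findfirst(==(a), typ.args[2:end])
        i !== nothing && return (name, i)
    end
    return nothing
end

# Replace symbols in the arguments of a term according to the dictionary d.
substitute(e::Symbol, d) = get(d, e, e)
substitute(e::Expr, d) = Expr(e.head, e.args[1], [substitute(a, d) for a in e.args[2:end]]...)
substitute(e, d) = e

# Hand-written context for the axiom f == compose(f, id(B)) with f : Hom(A, B).
ctx = Dict(:A => :Ob, :B => :Ob, :f => :(Hom(A, B)))
lhs, rhs = :f, :(compose(f, id(B)))

unknowns = setdiff(syms(rhs), syms(lhs))  # B does not appear on the left
d = Dict(u => begin
        name, pos = find_in_context(ctx, u)
        :(proj($pos, type($name)))
    end for u in unknowns)
rule_rhs = substitute(rhs, d)
```

The resulting right-hand side mentions only f, so the rule f => compose(f, id(proj(2, type(f)))) is usable left to right.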

Partial Evaluation of a Pattern Matcher for E-graphs

By: Philip Zucker

Re-posted from: https://www.philipzucker.com/staging-patterns/

Partial Evaluation is cool.

There are different interpretations of what partial evaluation means. I think the term can be applied to anything that you’re, well, partially evaluating. This might include really reaching down into and ripping apart the AST of a code block, finding things like 1 + 7 and replacing them with 8.
However, my mental model of partial evaluation is a function that takes 2 arguments, one given at compile time and one at run time. Given the compile time argument, you can hopefully resolve all sorts of decisions and control flow choices ahead of time based on that information.
What I know of this I’ve learned from reading Oleg Kiselyov http://okmij.org/ftp/meta-programming/index.html and his and others’ work in MetaOCaml. Sometimes this technique is called staged metaprogramming, or generative metaprogramming, because you don’t really peek inside of Expr.
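
The two-argument mental model can be made concrete with the power function, a standard staging example. In the sketch below the names power and staged_power are made up for illustration. The exponent n plays the role of the compile time argument and x the run time one.

```julia
# Unstaged: both n and x are ordinary runtime values.
power(n::Int, x) = n == 0 ? one(x) : x * power(n - 1, x)

# Staged: n is known now, while x is a piece of code (a Symbol or Expr).
# The recursion on n runs at generation time and leaves only multiplications.
staged_power(n::Int, x) = n == 0 ? 1 : :($x * $(staged_power(n - 1, x)))

# Wrap the generated body in a lambda; the body is x * (x * (x * 1)).
code = :(x -> $(staged_power(3, :x)))
cube = eval(code)
```

The two definitions are structurally the same. The staged one just quotes the parts that depend on x, and it never inspects the expression it is handed.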

I wrote an implementation of pattern matching in egraphs that is slow. https://www.philipzucker.com/egraph-2/ I think partial evaluation in this style is a nice way to speed it up.

A useful pattern for organizing programming tasks is to build tiny programming languages (DSLs) and then write a function that takes a data structure containing a program written in this tiny language and interprets it over inputs. It is very often the case that the actual form of the program one is going to interpret is known at compile time.

A reasonable problem solving strategy is

  1. Try some maximally concrete examples and see what hand written code you would write
  2. Write an interpreter over some DSL describing the class of programs
  3. Stage the interpreter with quasiquotation, rearranging minimally as necessary.
  4. Check that your examples produce code that is the same as, or good enough compared to, your hand written code.

So let’s start with 1.

Pattern matching feels a little scary, but when you actually sit down to do it on a concrete pattern, it ends up being a straightforward and boring pile of if-then-else clauses.

Here are some straightforward examples of regular pattern matching.

# pattern f(X)
function matchfx(e)
    # In a Julia Expr, the call f(x) has head :call and args [:f, x]
    if e isa Expr
        if e.head == :call && e.args[1] == :f && length(e.args) == 2
            return e.args[2]
        end
    end
end

# pattern f(X,Y)
function matchfxy(e)
    if e isa Expr
        if e.head == :call && e.args[1] == :f && length(e.args) == 3
            return [:X => e.args[2], :Y => e.args[3]]
        end
    end
end

E-matching throws a wrench into this simplicity, because you need to return multiple matches and search through the multiple enodes in each eclass. Nevertheless, for concrete patterns, it isn’t really that complicated either.

# pattern f(X)
function matchfx(G, e::Int64)
    buf = []
    for enode in G[e]
        if enode.head == :f && length(enode.args) == 1
            push!(buf, enode.args[1])
        end
    end
    return buf
end

# pattern f(X,Y) returns 2 variable bindings
function matchfxy(G, e::Int64)
    buf = []
    for enode in G[e]
        if enode.head == :f && length(enode.args) == 2
            push!(buf, [enode.args[1], enode.args[2]])
        end
    end
    return buf
end

# pattern f(X,X) requires a check
function matchfxx(G, e::Int64)
    buf = []
    for enode in G[e]
        if enode.head == :f && length(enode.args) == 2 && enode.args[1] == enode.args[2]
            push!(buf, enode.args[1])
        end
    end
    return buf
end

# pattern f(g(X)) requires deeper traversal. The number of nested loops is proportional to the number of eclasses expanded
function matchfgx(G, e::Int64)
    buf = []
    for enode in G[e]
        if enode.head == :f && length(enode.args) == 1
            for enode2 in G[enode.args[1]]
                if enode2.head == :g && length(enode2.args) == 1
                    push!(buf, enode2.args[1])
                end
            end
        end
    end
    return buf
end
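
To try a hand-written matcher like these without a full egraph implementation, a toy stand-in is enough. The sketch below assumes an enode is a head symbol plus a list of child eclass ids, and that the "egraph" is a plain Dict from eclass id to its enodes. ToyENode, matchfxgx and the sample graph are all made up for illustration. The pattern f(X, g(X)) combines the nested loop with the equality check.

```julia
# Toy stand-ins (assumptions for illustration, not the real egraph types).
struct ToyENode
    head::Symbol
    args::Vector{Int}  # child eclass ids
end

# pattern f(X, g(X)): one nested loop per eclass opened, plus an equality check
function matchfxgx(G, e::Int)
    buf = []
    for enode in G[e]
        if enode.head == :f && length(enode.args) == 2
            for enode2 in G[enode.args[2]]
                if enode2.head == :g && length(enode2.args) == 1 &&
                   enode2.args[1] == enode.args[1]
                    push!(buf, enode.args[1])
                end
            end
        end
    end
    return buf
end

G = Dict(
    1 => [ToyENode(:a, [])],
    2 => [ToyENode(:b, [])],
    3 => [ToyENode(:g, [1])],                          # g(a)
    4 => [ToyENode(:f, [1, 3]), ToyENode(:f, [2, 3])], # f(a, g(a)) and f(b, g(a))
)

matchfxgx(G, 4)  # only f(a, g(a)) matches, binding X to eclass 1
```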

Ok, well how can we generate these?

An Interpreter

I just so happen to know from the egg https://egraphs-good.github.io/ paper and the de Moura, Bjorner paper http://leodemoura.github.io/files/ematching.pdf that it is smart to compile patterns into a sequence of instructions. This is also similar to what is done in Prolog with the Warren Abstract Machine. Sequences of instructions clarify the control flow compared to directly writing recursive functions to interpret patterns. This explicit control flow is subtly important for the ability to stage programs.

For what I’ve shown above we need 3 instructions

  1. Opening up an eclass. This makes a for loop
  2. Checking if two variables are the same like in f(X,X)
  3. Yielding a binding of variables

Here is a simple data type to describe patterns and a helper macro to construct them. It considers capital symbols to be variables.

using AutoHashEquals # provides @auto_hash_equals

@auto_hash_equals struct Pattern
    head::Symbol
    args
end

struct PatVar
    name::Symbol
end

pat(e::Symbol) = isuppercase(String(e)[1]) ? PatVar(e) : Pattern(e, [])
pat(e::Expr) = Pattern(e.args[1], pat.(e.args[2:end]) )
macro pat(e) 
    return pat(e)
end
# example usage: @pat f(g(X), X)

And here are the data types Bind, CheckClassEq, and Yield for the instructions of my language, corresponding to 1-3 above.

@auto_hash_equals struct ENodePat
    head::Symbol
    args::Vector{Symbol}
end

@auto_hash_equals struct Bind 
    eclass::Symbol
    enodepat::ENodePat
end

@auto_hash_equals struct CheckClassEq
    eclass1::Symbol
    eclass2::Symbol
end

@auto_hash_equals struct Yield
    yields::Vector{Symbol}
end

Here is a function that takes a Pattern and turns it into a sequence of instructions by traversing it depth first. This may not be the optimal ordering to produce. During the pass I maintain a context in which you can look up variables from the Pattern and find the new names that I bound for them in the Insns.

function compile_pat(eclass, p::Pattern,ctx)
    a = [gensym() for i in 1:length(p.args) ]
    return vcat(  Bind(eclass, ENodePat(p.head, a  )) , [compile_pat(eclass, p2, ctx) for (eclass , p2) in zip(a, p.args)]...)
end

function compile_pat(eclass, p::PatVar,ctx)
    if haskey(ctx, p.name)
        return CheckClassEq(eclass, ctx[p.name])
    else
        ctx[p.name] = eclass
        return []
    end
end
    
function compile_pat(p)
    ctx = Dict()
    insns = compile_pat(:start, p, ctx)
    
    return vcat(insns, Yield( [values(ctx)...]) ), ctx
end
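
To see what the compiler is meant to emit, here is a simplified sketch of the same depth first pass. It is not the code above: instructions are encoded as plain tuples instead of structs, patterns are ordinary Julia Exprs, and fresh names come from a counter rather than gensym so that the output is predictable. The names isvar, compile_sketch and compile_sketch! are made up for this example.

```julia
# Capitalized symbols are pattern variables, as in the @pat macro.
isvar(p) = p isa Symbol && isuppercase(first(String(p)))

function compile_sketch!(out, eclass, p, ctx, fresh)
    if isvar(p)
        if haskey(ctx, p)
            push!(out, (:check, eclass, ctx[p]))  # variable seen before: compare eclasses
        else
            ctx[p] = eclass                       # first occurrence: just record the name
        end
    else
        head, args = p isa Expr ? (p.args[1], p.args[2:end]) : (p, [])
        names = [fresh() for _ in args]
        push!(out, (:bind, eclass, head, names))  # open the eclass, naming the children
        for (name, arg) in zip(names, args)
            compile_sketch!(out, name, arg, ctx, fresh)
        end
    end
    return out
end

function compile_sketch(p)
    counter = Ref(0)
    fresh() = (counter[] += 1; Symbol(:e, counter[]))
    ctx = Dict{Symbol,Symbol}()
    out = compile_sketch!(Any[], :start, p, ctx, fresh)
    return push!(out, (:yield, collect(values(ctx))))
end

compile_sketch(:(f(g(X), X)))
```

For f(g(X), X) this produces two binds (one for f, one for g), a check that the second child of f is the same eclass as the child of g, and a yield of the eclass bound to X.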

Then I can write an interpreter over these instructions. You case on the instruction type. If it is a Bind, you make a for loop over all ENodes in that eclass, check that the head matches the pattern’s head and the number of arguments is right, add all the newly found child eclass numbers to the context, and finally recurse into the next instruction. If it is a CheckClassEq, you check with the context to see if we should stop this branch of the search. If it is a Yield, we add the current solution to the buf. The buf thing came up in Alessandro’s journey to optimize the ematcher in Metatheory.jl https://github.com/0x0f0f0f/Metatheory.jl and it really cleared things up for me compared to my monstrous Channel based implementation.

function interp_unstaged(G, insns, ctx, buf) 
    if length(insns) == 0
        return 
    end
    insn = insns[1]
    insns = insns[2:end]
    if insn isa Bind
        for enode in G[ctx[insn.eclass]] 
            if enode.head == insn.enodepat.head && length(enode.args) == length(insn.enodepat.args)
                # record the eclass ids of the children under their new names
                for (n,v) in enumerate(insn.enodepat.args)
                    ctx[v] = enode.args[n]
                end
                interp_unstaged(G, insns, ctx, buf)
            end
        end
    elseif insn isa Yield
        push!(buf, [ctx[y] for y in insn.yields])
    elseif insn isa CheckClassEq
        if ctx[insn.eclass1] == ctx[insn.eclass2]
            interp_unstaged(G, insns, ctx, buf)
        end
    end
end

function interp_unstaged(G, insns, eclass0)
        buf = []
        interp_unstaged(G, insns, Dict{Symbol,Any}([:start => eclass0]) , buf)
        return buf
end

Ok, that is kind of complicated but also kind of not. It may be more complicated to explain and read than it is to write.

Let’s stage it!

Staging the Interpreter

We know the Pattern at compile time, but we want to run matches many, many times quickly at runtime. It makes sense to partially evaluate things. Let’s look at a related example first.

An aside: Partially Evaluating Many For Loops

Consider a brute force search over 3 bits. You could write it like this

for i in (false,true)
    for j in (false, true)
        for k in (false,true)
            println([i,j,k])
        end
    end
end

Now suppose I want code that works for an arbitrary number n of bits. That is a bit more complicated. One way of doing it is via recursion. If you know how to solve the n-1 bits problem, all you need to do is run that problem in a single extra loop.

function allbits(n, partial)
   if n == 0
        println(partial)
        return 
   end
   for i in (false,true)
       partial[n] = i
       allbits(n-1, partial)
   end
end
function allbits(n) 
    allbits(n, fill(false, n))
end

By adding quasiquotation annotations to this code, we can actually produce the same code as above. We assume n is known at compile time. We also know the size of the partial vector, but the values inside the partial vector are no longer boolean values; instead they represent the code of boolean values available at runtime.
It is unfortunate that I need to gensym the i variable. Perhaps there are facilities in Julia to avoid this, perhaps not. Note however that the code is otherwise structurally identical to the above.

function allbits(n, partial::Vector{Symbol})
   if n == 0
        return :(println([$(partial...)]))  
   end
    i = gensym()
    quote
       for $i in (false,true)
           $(begin
                partial[n] = i
                allbits(n-1, partial)
            end)
       end
   end
end
function allbits(n) 
    allbits(n, fill(:foo, n))
end

using MacroTools # for prettify
prettify(allbits(3))
#=
:(for coati = (false, true)
      for caterpillar = (false, true)
          for seahorse = (false, true)
              println([seahorse, caterpillar, coati])
          end
      end
  end)
=#
# To use
eval(allbits(3))
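
Step 4 of the strategy says to check the staged code against the original. Printing is awkward to compare, so here is a sketch of both versions rewritten to collect their results into a vector instead. The names allbits_rec! and allbits_code and the out parameter are made up for this check; the structure otherwise follows the two versions above.

```julia
# Unstaged version, collecting every assignment instead of printing it.
function allbits_rec!(out, n, partial)
    if n == 0
        push!(out, copy(partial))
        return out
    end
    for i in (false, true)
        partial[n] = i
        allbits_rec!(out, n - 1, partial)
    end
    return out
end

# Staged version: partial holds the names of the loop variables, not their values.
function allbits_code(n, partial::Vector{Symbol}, out::Symbol)
    if n == 0
        return :(push!($out, [$(partial...)]))
    end
    i = gensym(:bit)  # each nesting level needs its own loop variable
    partial[n] = i
    body = allbits_code(n - 1, partial, out)
    return quote
        for $i in (false, true)
            $body
        end
    end
end

code = quote
    out -> begin
        $(allbits_code(3, fill(:unset, 3), :out))
        out
    end
end
staged = eval(code)
```

Calling staged(Vector{Bool}[]) at the top level should give the same eight assignments, in the same order, as allbits_rec!(Vector{Bool}[], 3, fill(false, 3)). If the call happens inside the function that did the eval, Base.invokelatest is needed to get around world age restrictions.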

Staged Interpreter

I just so happen to have written the unstaged interpreter such that very minimal changes are necessary to build the staged version.

Now the ctx is a compile time mapping of Insn variables to the expressions that at runtime will hold the appropriate values. G and buf are the runtime values that hold the egraph and vector of return variable bindings.
The way I did this was to go through and think about what is available at compile time, keeping it out of quotes, and what is only available at runtime, putting it in quotes. It is a remarkably mechanical process to write this stuff once you get the hang of it. I suggest reading Oleg Kiselyov’s tutorials and mini-book for more.

Code = Union{Expr,Symbol}

function interp_staged(G::Code, insns, ctx, buf::Code) # G could be an argument or a global variable.
    if length(insns) == 0
        return 
    end
    insn = insns[1]
    insns = insns[2:end]
    if insn isa Bind
        enode = gensym(:enode)
        quote
            for $enode in $G[$(ctx[insn.eclass])] # we can store an integer of the eclass.
                if $enode.head == $(QuoteNode(insn.enodepat.head)) && length($enode.args) == $(length(insn.enodepat.args))
                    $(
                        begin
                        for (n,v) in enumerate(insn.enodepat.args)
                            ctx[v] = :($enode.args[$n])
                        end
                        interp_staged(G,  insns, ctx, buf)
                        end
                    )
                end
            end
        end
    elseif insn isa Yield
        quote
            push!( $buf, [ $( [ctx[y] for y in insn.yields]...) ])
        end
    elseif insn isa CheckClassEq
        quote
            if $(ctx[insn.eclass1]) == $(ctx[insn.eclass2])
                $(interp_staged(G, insns, ctx, buf))
            end
        end
    end
end

function interp_staged(insns)
    quote
        (G, eclass0) -> begin
            buf = []
            $(interp_staged(:G, insns, Dict{Symbol,Any}([:start => :eclass0]) , :buf))
            return buf
        end
    end
end

using MacroTools # for prettify
p1, ctx1 = compile_pat(@pat f(X))
println(prettify(interp_staged( p1  )))
#=
(G, eclass0)->begin
        buf = []
        for sardine = G[eclass0]
            if sardine.head == :f && length(sardine.args) == 1
                push!(buf, [sardine.args[1]])
            end
        end
        return buf
    end
=#

Bits and Bobbles

Example tests

using Test

@testset "Basic Compile" begin
    @test compile_pat(@pat f)[1] == [Bind(:start, ENodePat(:f, [])), Yield([])]

    p1, ctx1 = compile_pat(@pat f(X))
    @test p1 == [Bind(:start, ENodePat(:f,  [ctx1[:X]]  )), Yield([ctx1[:X]])]
    
    p1, ctx = compile_pat(@pat f(X, Y))
    @test p1 == [Bind(:start, ENodePat(:f,  [ctx[:X], ctx[:Y] ]  )), 
                 Yield([ctx[:X], ctx[:Y]])]
    @test ctx[:X] != ctx[:Y]
    @test length(ctx) == 2

    p1, ctx = compile_pat(@pat f(X, X))
    x2 = p1[1].enodepat.args[2]
    @test length(ctx) == 1
    @test p1 == [Bind(:start, ENodePat(:f,  [ctx[:X], x2 ]  )), 
                 CheckClassEq( x2 , ctx[:X] ),
                 Yield([ctx[:X]])]

    p1, ctx = compile_pat(@pat f(g(X)))
    g = p1[1].enodepat.args[1]
    @test length(ctx) == 1
    @test p1 == [Bind(:start, ENodePat(:f,  [g ]  )),
                Bind(g, ENodePat(:g,  [ctx[:X] ]  )) ,
                Yield([ctx[:X]])]

end

using MacroTools
@testset "Basic Match" begin
    G = EGraph()
    a = addexpr!(G, :a) 
    b = addexpr!(G, :b) 
    fa = addexpr!(G, :( f(a)  ))
    fb = addexpr!(G, :( f(b)  ))

    p1, ctx1 = compile_pat(@pat f(X))

    println(prettify(interp_staged( p1  )))
    match = eval(interp_staged( p1  ))
    @test match(G, fa) == [ [ a ] ]
    @test match(G, fb) == [ [ b ] ]
    union!(G, fa, fb)
    #println(G)
    @test match(G, fa) == [ [ b ], [ a ]  ]
    union!(G, a, b)
    #println(G)
    #println(congruences(G))
    propagate_congruence(G)
    #println(G)
    #println(congruences(G))
    rebuild!(G) # There is something kind of sick here.
    #println(G)
    a = addexpr!(G, :a)
    @test match(G, fa) == [ [ a ] ]


    p1, ctx = compile_pat(@pat f(g(X)))
    match = eval(interp_staged( p1  ))
    G = EGraph()
    a = addexpr!(G, :a) 
    b = addexpr!(G, :b) 
    fga = addexpr!(G, :( f(g(a))  ))
    fb = addexpr!(G, :( f(b)  ))
    @test match(G,fb) == []
    @test match(G,fga) == [[a]]
    union!(G, fb, fga)
    @test match(G,fb) == [[a]]
    @test match(G,fga) == [[a]]



    p1, ctx = compile_pat(@pat f(X, X))
    #println(prettify(interp_staged( p1  )))
    match = eval(interp_staged( p1  ))
    G = EGraph()
    a = addexpr!(G, :a) 
    b = addexpr!(G, :b) 
    fga = addexpr!(G, :( f(g(a))  ))
    fb = addexpr!(G, :( f(b)  ))
    fab = addexpr!(G, :( f(a,b)  ))
    @test match(G,fab) == []
    faa = addexpr!(G, :( f(a,a)  ))
    @test match(G,faa) == [[a]]
    @test match(G,fab) == []
    union!(G,a,b)
    rebuild!(G)
    #propagate_congruence(G)
    @test match(G,faa) == [[b]]
    @test match(G,fab) == [[b]]



    #println(prettify(interp_staged( p1  )))
end