diff --git a/src/Patterns.jl b/src/Patterns.jl index 2084cc63..4b595b66 100644 --- a/src/Patterns.jl +++ b/src/Patterns.jl @@ -115,6 +115,10 @@ PatExpr(iscall, op, args::Vector) = PatExpr(iscall, op, maybe_quote_operation(op isground(p::PatExpr)::Bool = p.isground +function Base.isequal(x::PatExpr, y::PatExpr) + x.head_hash === y.head_hash && v_signature(x.n)===v_signature(y.n) && all(x.children .== y.children) +end + TermInterface.isexpr(::PatExpr) = true TermInterface.head(p::PatExpr) = p.head TermInterface.operation(p::PatExpr) = p.head diff --git a/src/Rules.jl b/src/Rules.jl index b631c5c9..634b5cba 100644 --- a/src/Rules.jl +++ b/src/Rules.jl @@ -4,9 +4,9 @@ using TermInterface using AutoHashEquals using Metatheory.Patterns using Metatheory.Patterns: to_expr -using Metatheory: OptBuffer +using Metatheory: OptBuffer, match_compile -export RewriteRule, DirectedRule, EqualityRule, UnequalRule, DynamicRule, -->, is_bidirectional, Theory +export RewriteRule, DirectedRule, EqualityRule, UnequalRule, DynamicRule, -->, is_bidirectional, Theory, direct, direct_left_to_right, direct_right_to_left const STACK_SIZE = 512 @@ -140,6 +140,42 @@ instantiate_arg!(acc, left, parg::AbstractPat, bindings) = push!(acc, instantiat instantiate(_, pat::PatLiteral, bindings) = pat.value instantiate(_, pat::Union{PatVar,PatSegment}, bindings) = bindings[pat.idx] +"Inverts the direction of a rewrite rule, swapping the LHS and the RHS" +function invert(r::RewriteRule) + RewriteRule( + name = r.name, + op = r.op, + left = r.right, + right = r.left, + patvars = r.patvars, + ematcher_left! = r.ematcher_right!, + ematcher_right! = r.ematcher_left!, + matcher_left = r.matcher_right, + matcher_right = r.matcher_left, + lhs_original = r.rhs_original, + rhs_original = r.lhs_original, + ) +end + +""" +Turns an EqualityRule into a DirectedRule. For example, +```julia +direct(@rule f(~x) == g(~x)) == f(~x) --> g(~x) +``` +""" +function direct(r::EqualityRule) + RewriteRule(r.name, -->, (getfield(r,k) for k in fieldnames(DirectedRule)[3:end])...) +end + +""" +Turns an EqualityRule into a DirectedRule, but right to left. For example, + +```julia +direct(@rule f(~x) == g(~x)) == g(~x) --> f(~x) +``` +""" +direct_right_to_left(r::EqualityRule) = invert(direct(r)) +direct_left_to_right(r::EqualityRule) = direct(r) end diff --git a/src/Syntax.jl b/src/Syntax.jl index 5cacd82b..2392911f 100644 --- a/src/Syntax.jl +++ b/src/Syntax.jl @@ -395,6 +395,7 @@ macro rule(args...) @assert pvars == ppvars ematcher_right_expr = :nothing + matcher_right_expr = :nothing rhs = rhs_original = :(println("replace me")) @@ -416,6 +417,7 @@ macro rule(args...) if op in (:(==), :(!=)) # Bidirectional rule ematcher_right_expr = esc(ematch_compile(rhs, pvars, -1)) + matcher_right_expr = esc(match_compile(rhs, pvars)) extravars = setdiff(pvars, patvars(lhs) ∩ patvars(rhs)) if !isempty(extravars) error("unbound pattern variables $extravars when creating bidirectional rule") @@ -438,6 +440,7 @@ macro rule(args...) ematcher_left! = $ematcher_left_expr, ematcher_right! = $ematcher_right_expr, matcher_left = $matcher_left_expr, + matcher_right = $matcher_right_expr, lhs_original = $(QuoteNode(l)), rhs_original = $(QuoteNode(rhs_original)), ) diff --git a/test/classic/reductions.jl b/test/classic/reductions.jl index fdb4a610..ab1b14c8 100644 --- a/test/classic/reductions.jl +++ b/test/classic/reductions.jl @@ -295,7 +295,6 @@ using Metatheory.Syntax: @capture r = (@capture x ~x) @test r == true end - module QuxTest using Metatheory, Test, TermInterface struct Qux diff --git a/test/unit/rules.jl b/test/unit/rules.jl index 3a037120..8989e4a2 100644 --- a/test/unit/rules.jl +++ b/test/unit/rules.jl @@ -30,4 +30,19 @@ end r = @rule Main.f(~~x) --> ~x r == eval(:(@rule $(Meta.parse(repr(r))))) -end \ No newline at end of file +end + + +@testset "EqualityRule to DirectedRule(s)" begin + r = @rule "distributive" x y z x*(y + z) == x*y + x*z + r_ltr = @rule "distributive" x y z x * (y + z) --> x*y + x*z + r_rtl = @rule "distributive" x y z x*y + x*z --> x * (y + z) + r1 = direct(r) + r2 = Metatheory.direct_right_to_left(r) + + @test r1 isa DirectedRule + @test r2 isa DirectedRule + @test repr(r1) == repr(r_ltr) + @test repr(r2) == repr(r_rtl) +end +