diff --git a/Tactics/Aggregate.lean b/Tactics/Aggregate.lean index cdb2f6ca..19afad66 100644 --- a/Tactics/Aggregate.lean +++ b/Tactics/Aggregate.lean @@ -6,6 +6,7 @@ Author(s): Alex Keizer, Siddharth Bhat import Lean import Tactics.Common import Tactics.Simp +import Tactics.Sym.LCtxSearch open Lean Meta Elab.Tactic @@ -65,44 +66,46 @@ elab "sym_aggregate" simpConfig?:(config)? loc?:(location)? : tactic => withMain let simpConfig? ← simpConfig?.mapM fun cfg => elabSimpConfig (mkNullNode #[cfg]) (kind := .simp) - let lctx ← getLCtx - -- We keep `expectedRead`/`expectedAlign` as monadic values, - -- so that we get new metavariables for each localdecl we check - let expectedRead : MetaM Expr := do - let fld ← mkFreshExprMVar (mkConst ``StateField) - let state ← mkFreshExprMVar mkArmState - let rhs ← mkFreshExprMVar none - mkEq (mkApp2 (mkConst ``r) fld state) rhs - let expectedReadMem : MetaM Expr := do - let n ← mkFreshExprMVar (mkConst ``Nat) - let addr ← mkFreshExprMVar (mkApp (mkConst ``BitVec) (toExpr 64)) - let mem ← mkFreshExprMVar (mkConst ``Memory) - let rhs ← mkFreshExprMVar none - mkEq (mkApp3 (mkConst ``Memory.read_bytes) n addr mem) rhs - let expectedAlign : MetaM Expr := do - let state ← mkFreshExprMVar mkArmState - return mkApp (mkConst ``CheckSPAlignment) state + /- + We construct `axHyps` by running a `State` monad, which is + initialized with an empty array + -/ + let ((), axHyps) ← StateT.run (s := #[]) <| + searchLCtx <| do + let whenFound := fun decl _ => do + -- Whenever a match is found, we add the corresponding declaration + -- to the `axHyps` array in the monadic state + modify (·.push decl) + return .continu - let axHyps ← - withTraceNode `Tactic.sym (fun _ => pure m!"searching for effect hypotheses") <| - lctx.foldlM (init := #[]) fun axHyps decl => do - forallTelescope decl.type <| fun _ type => do - trace[Tactic.sym] "checking {decl.toExpr} with type {type}" - let expectedRead ← expectedRead - let expectedAlign ← expectedAlign - let expectedReadMem ← expectedReadMem - if ← isDefEq type expectedRead then - trace[Tactic.sym] "{Lean.checkEmoji} match for {expectedRead}" - return axHyps.push decl - else if ← isDefEq type expectedAlign then - trace[Tactic.sym] "{Lean.checkEmoji} match for {expectedAlign}" - return axHyps.push decl - else if ← isDefEq type expectedReadMem then - trace[Tactic.sym] "{Lean.checkEmoji} match for {expectedReadMem}" - return axHyps.push decl - else - trace[Tactic.sym] "{Lean.crossEmoji} no match" - return axHyps + -- `r ?field ?state = ?rhs` + searchLCtxFor (whenFound := whenFound) + /- By matching under binders, this also matches for non-effect + hypotheses, which look like: + `∀ f, f ≠ _ → r f ?state = ?rhs` + -/ + (matchUnderBinders := true) + (expectedType := do + let fld ← mkFreshExprMVar (mkConst ``StateField) + let state ← mkFreshExprMVar mkArmState + let rhs ← mkFreshExprMVar none + return mkEqReadField fld state rhs + ) + -- `Memory.read_bytes ?n ?addr ?mem = ?rhs` + searchLCtxFor (whenFound := whenFound) + (matchUnderBinders := true) + (expectedType := do + let n ← mkFreshExprMVar (mkConst ``Nat) + let addr ← mkFreshExprMVar (mkApp (mkConst ``BitVec) (toExpr 64)) + let mem ← mkFreshExprMVar (mkConst ``Memory) + let rhs ← mkFreshExprMVar none + mkEq (mkApp3 (mkConst ``Memory.read_bytes) n addr mem) rhs + ) + -- `CheckSpAlignment ?state` + searchLCtxFor (whenFound := whenFound) + (expectedType := do + let state ← mkFreshExprMVar mkArmState + return mkApp (mkConst ``CheckSPAlignment) state) let loc := (loc?.map expandLocation).getD (.targets #[] true) aggregate axHyps loc simpConfig? diff --git a/Tactics/Sym/LCtxSearch.lean b/Tactics/Sym/LCtxSearch.lean index 58d4c3ea..bc59eb5d 100644 --- a/Tactics/Sym/LCtxSearch.lean +++ b/Tactics/Sym/LCtxSearch.lean @@ -61,6 +61,8 @@ structure LCtxSearchState.Pattern where whenNotFound : Expr → m Unit /-- Whether to change the type of successful matches -/ changeType : Bool + /-- Whether to match under binders -/ + matchUnderBinders : Bool /-- How many times have we (successfully) found the pattern -/ occurences : Nat := 0 /-- Whether the pattern is active; is `isActive = false`, @@ -95,9 +97,14 @@ variable {m} def-eq to the pattern, or if `whenFound` returned `skip` for all variables that were found. For convenience, we pass in the `expectedType` here as well. See `throwNotFound` for a convenient way to throw an error here. -- If `changeType` is set to true, then we change the type of every successful - match (i.e., `whenFound` returns `continu` or `done`) to be exactly the - `expectedType` +- If `changeType` (default `false`) is set to true, then we change the type of + every successful match (i.e., `whenFound` returns `continu` or `done`) + to be exactly the `expectedType` +- If `matchUnderBinders` (default `false`) is set to true, we will introduce + metavariable for any binders in a variable's type before matching. + For example, with `matchUnderBinders` set to true, we consider a variable + `h : ∀ f, r f s0 = r f s1` + as a match for expected type `r ?f s0 = r ?f s1`. WARNING: Once a pattern is found for which `whenFound` returns `done`, that particular variable will not be matched for any other patterns. @@ -109,11 +116,12 @@ def searchLCtxFor (whenFound : LocalDecl → Expr → m LCtxSearchResult) (whenNotFound : Expr → m Unit := fun _ => pure ()) (changeType : Bool := false) + (matchUnderBinders : Bool := false) : SearchLCtxForM m Unit := do let pattern := { -- Placeholder value, since we can't evaluate `m` inside of `LCtxSearchM` cachedExpectedType :=← expectedType - expectedType, whenFound, whenNotFound, changeType + expectedType, whenFound, whenNotFound, changeType, matchUnderBinders } modify fun state => { state with patterns := state.patterns.push pattern @@ -128,16 +136,27 @@ def searchLCtxForOnce (whenFound : LocalDecl → Expr → m Unit) (whenNotFound : Expr → m Unit := fun _ => pure ()) (changeType : Bool := false) + (matchUnderBinders : Bool := false) : SearchLCtxForM m Unit := do searchLCtxFor (pure expectedType) (fun d e => do whenFound d e; return .done) - whenNotFound changeType + whenNotFound changeType matchUnderBinders section Run open Elab.Tactic open Meta (isDefEq) variable [MonadLCtx m] [MonadLiftT MetaM m] [MonadLiftT TacticM m] +namespace LCtxSearchState + +/-- Return `true` if `e` matches the pattern -/ +def Pattern.matches (pat : Pattern m) (e : Expr) : m Bool := do + let mut e := e + if pat.matchUnderBinders then + let ⟨_, _, e'⟩ ← Meta.forallMetaTelescope e + e := e' + isDefEq e pat.cachedExpectedType + /-- Attempt to match `e` against the given pattern: - if `e` is def-eq to `pat.cachedExpectedType`, then return @@ -145,11 +164,11 @@ Attempt to match `e` against the given pattern: the result of `whenFound` - Otherwise, if `e` is not def-eq, return `none` -/ -def LCtxSearchState.Pattern.match? (pat : Pattern m) (decl : LocalDecl) : +def Pattern.match? (pat : Pattern m) (decl : LocalDecl) : m (Option (Pattern m × LCtxSearchResult)) := do if !pat.isActive then return none - else if !(← isDefEq decl.type pat.cachedExpectedType) then + else if !(← pat.matches decl.type) then return none else let cachedExpectedType ← pat.expectedType @@ -163,6 +182,8 @@ def LCtxSearchState.Pattern.match? (pat : Pattern m) (decl : LocalDecl) : replaceMainGoal [goal] return some ({pat with cachedExpectedType, occurences}, res) +end LCtxSearchState + /-- Search the local context for variables of certain types, in a single pass. `k` is a monadic continuation that determines the patterns to search for, see `searchLCtxFor` to see how to register those patterns