diff --git a/Tactics/Sym/Context.lean b/Tactics/Sym/Context.lean index b600aed8..1505f4e8 100644 --- a/Tactics/Sym/Context.lean +++ b/Tactics/Sym/Context.lean @@ -271,7 +271,6 @@ private def initial (state : Expr) : MetaM SymContext := do effects := AxEffects.initial state } - /-- Infer `state_prefix` and `curr_state_number` from the `state` name as follows: if `state` is `s{i}` for some number `i` and a single character `s`, then `s` is the prefix and `i` the number, @@ -365,6 +364,7 @@ protected def searchFor : SearchLCtxForM SymM Unit := do -- Find `h_pc : r .PC = ` let pc ← mkFreshExprMVar (← mkAppM ``BitVec #[toExpr 64]) searchLCtxForOnce (h_pc_type initialState pc) + (changeType := true) (whenNotFound := throwNotFound) (whenFound := fun decl _ => do let pc ← instantiateMVars pc @@ -381,6 +381,7 @@ protected def searchFor : SearchLCtxForM SymM Unit := do -- Find `h_err : r .ERR = .None`, or add a new subgoal if it isn't found searchLCtxForOnce (h_err_type initialState) + (changeType := true) (whenFound := fun decl _ => AxEffects.setErrorProof decl.toExpr ) @@ -392,10 +393,11 @@ protected def searchFor : SearchLCtxForM SymM Unit := do -- Find `h_sp : CheckSPAlignment `. searchLCtxForOnce (h_sp_type initialState) + (changeType := true) (whenNotFound := traceNotFound `Tactic.sym) -- ^^ Note that `h_sp` is optional, so there's no need to throw an error, -- we merely add a message to the trace and move on - (whenFound := fun decl _ => + (whenFound := fun decl _ => do modifyThe AxEffects ({ · with stackAlignmentProof? := some decl.toExpr }) @@ -422,7 +424,7 @@ we create a new subgoal of this type -/ def fromLocalContext (state? : Option Name) : TacticM SymContext := do let msg := m!"Building a `SymContext` from the local context" - withTraceNode `Tactic.sym (fun _ => pure msg) do + withTraceNode `Tactic.sym (fun _ => pure msg) <| withMainContext' do trace[Tactic.Sym] "state? := {state?}" let lctx ← getLCtx @@ -440,55 +442,11 @@ def fromLocalContext (state? : Option Name) : TacticM SymContext := do c.modify <| do searchLCtx SymContext.searchFor - let thms ← (← readThe AxEffects).toSimpTheorems - modifyThe SymContext (·.addSimpTheorems thms) - - inferStatePrefixAndNumber -where - findLocalDeclOfType? (expectedType : Expr) : MetaM (Option LocalDecl) := do - let msg := m!"Searching for hypothesis of type: {expectedType}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do - let decl? ← _root_.findLocalDeclOfType? expectedType - trace[Tactic.sym] "Found: {(·.toExpr) <$> decl?}" - return decl? - findLocalDeclOfTypeOrError (expectedType : Expr) : MetaM LocalDecl := do - let msg := m!"Searching for hypothesis of type: {expectedType}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| do - let decl ← _root_.findLocalDeclOfTypeOrError expectedType - trace[Tactic.sym] "Found: {decl.toExpr}" - return decl - -/-! ## Massaging the local context -/ - -/-- change the type (in the local context of the main goal) -of the hypotheses tracked by the given `SymContext` to be *exactly* of the shape -described in the relevant docstrings. - -That is, (un)fold types which were definitionally, but not syntactically, -equal to the expected shape. -/ -def canonicalizeHypothesisTypes : SymReaderM Unit := withMainContext' do - let c ← readThe SymContext - let lctx ← getLCtx - let mut goal ← getMainGoal - let state := c.effects.currentState - - let mut hyps := #[] - if let some runSteps := c.runSteps? then - hyps := hyps.push (c.h_run, h_run_type c.finalState (toExpr runSteps) state) - if let some h_sp := c.h_sp? then - hyps := hyps.push (h_sp, h_sp_type state) - - let mut hypIds ← hyps.mapM fun ⟨name, type⟩ => do - let some decl := lctx.findFromUserName? name - | throwError "Unknown local hypothesis `{name}`" - pure (decl.fvarId, type) - - let errHyp ← AxEffects.getFieldM .ERR - if let Expr.fvar id := errHyp.proof then - hypIds := hypIds.push (id, h_err_type state) - for ⟨fvarId, type⟩ in hypIds do - goal ← goal.replaceLocalDeclDefEq fvarId type - replaceMainGoal [goal] + withMainContext' <| do + let thms ← (← readThe AxEffects).toSimpTheorems + modifyThe SymContext (·.addSimpTheorems thms) + + inferStatePrefixAndNumber /-! ## Incrementing the context to the next state -/ diff --git a/Tactics/Sym/LCtxSearch.lean b/Tactics/Sym/LCtxSearch.lean index 34e33d72..58d4c3ea 100644 --- a/Tactics/Sym/LCtxSearch.lean +++ b/Tactics/Sym/LCtxSearch.lean @@ -59,6 +59,8 @@ structure LCtxSearchState.Pattern where (as determined by the return value of `whenFound`) could be found in the local context -/ whenNotFound : Expr → m Unit + /-- Whether to change the type of successful matches -/ + changeType : Bool /-- How many times have we (successfully) found the pattern -/ occurences : Nat := 0 /-- Whether the pattern is active; is `isActive = false`, @@ -93,6 +95,9 @@ variable {m} def-eq to the pattern, or if `whenFound` returned `skip` for all variables that were found. For convenience, we pass in the `expectedType` here as well. See `throwNotFound` for a convenient way to throw an error here. +- If `changeType` is set to true, then we change the type of every successful + match (i.e., `whenFound` returns `continu` or `done`) to be exactly the + `expectedType` WARNING: Once a pattern is found for which `whenFound` returns `done`, that particular variable will not be matched for any other patterns. @@ -103,30 +108,35 @@ def searchLCtxFor (expectedType : m Expr) (whenFound : LocalDecl → Expr → m LCtxSearchResult) (whenNotFound : Expr → m Unit := fun _ => pure ()) + (changeType : Bool := false) : SearchLCtxForM m Unit := do let pattern := { -- Placeholder value, since we can't evaluate `m` inside of `LCtxSearchM` cachedExpectedType :=← expectedType - expectedType, whenFound, whenNotFound + expectedType, whenFound, whenNotFound, changeType } modify fun state => { state with patterns := state.patterns.push pattern } /-- A wrapper around `searchLCtxFor`, which is simplified for matching at most -one occurence of `expectedType` -/ +one occurence of `expectedType`. + +See `searchLCtxFor` for an explanation of the arguments -/ def searchLCtxForOnce (expectedType : Expr) (whenFound : LocalDecl → Expr → m Unit) (whenNotFound : Expr → m Unit := fun _ => pure ()) + (changeType : Bool := false) : SearchLCtxForM m Unit := do searchLCtxFor (pure expectedType) (fun d e => do whenFound d e; return .done) - whenNotFound + whenNotFound changeType section Run +open Elab.Tactic open Meta (isDefEq) -variable [MonadLCtx m] [MonadLiftT MetaM m] +variable [MonadLCtx m] [MonadLiftT MetaM m] [MonadLiftT TacticM m] /-- Attempt to match `e` against the given pattern: @@ -144,9 +154,13 @@ def LCtxSearchState.Pattern.match? (pat : Pattern m) (decl : LocalDecl) : else let cachedExpectedType ← pat.expectedType let res ← pat.whenFound decl pat.cachedExpectedType - let occurences := match res with - | .skip => pat.occurences - | .done | .continu => pat.occurences + 1 + let mut occurences := pat.occurences + if res != .skip then + occurences := occurences + 1 + if pat.changeType = true then + let goal ← getMainGoal + let goal ← goal.replaceLocalDeclDefEq decl.fvarId pat.cachedExpectedType + replaceMainGoal [goal] return some ({pat with cachedExpectedType, occurences}, res) /-- Search the local context for variables of certain types, in a single pass.