Skip to content

Commit

Permalink
add option for type canonicalization to LCtxSearch
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkeizer committed Sep 25, 2024
1 parent 840c170 commit 20173f9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 59 deletions.
62 changes: 10 additions & 52 deletions Tactics/Sym/Context.lean
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ private def initial (state : Expr) : MetaM SymContext := do
effects := AxEffects.initial state
}


/-- Infer `state_prefix` and `curr_state_number` from the `state` name
as follows: if `state` is `s{i}` for some number `i` and a single character `s`,
then `s` is the prefix and `i` the number,
Expand Down Expand Up @@ -365,6 +364,7 @@ protected def searchFor : SearchLCtxForM SymM Unit := do
-- Find `h_pc : r .PC <s> = <pc>`
let pc ← mkFreshExprMVar (← mkAppM ``BitVec #[toExpr 64])
searchLCtxForOnce (h_pc_type initialState pc)
(changeType := true)
(whenNotFound := throwNotFound)
(whenFound := fun decl _ => do
let pc ← instantiateMVars pc
Expand All @@ -381,6 +381,7 @@ protected def searchFor : SearchLCtxForM SymM Unit := do

-- Find `h_err : r .ERR <s> = .None`, or add a new subgoal if it isn't found
searchLCtxForOnce (h_err_type initialState)
(changeType := true)
(whenFound := fun decl _ =>
AxEffects.setErrorProof decl.toExpr
)
Expand All @@ -392,10 +393,11 @@ protected def searchFor : SearchLCtxForM SymM Unit := do

-- Find `h_sp : CheckSPAlignment <initialState>`.
searchLCtxForOnce (h_sp_type initialState)
(changeType := true)
(whenNotFound := traceNotFound `Tactic.sym)
-- ^^ Note that `h_sp` is optional, so there's no need to throw an error,
-- we merely add a message to the trace and move on
(whenFound := fun decl _ =>
(whenFound := fun decl _ => do
modifyThe AxEffects ({ · with
stackAlignmentProof? := some decl.toExpr
})
Expand All @@ -422,7 +424,7 @@ we create a new subgoal of this type
-/
def fromLocalContext (state? : Option Name) : TacticM SymContext := do
let msg := m!"Building a `SymContext` from the local context"
withTraceNode `Tactic.sym (fun _ => pure msg) do
withTraceNode `Tactic.sym (fun _ => pure msg) <| withMainContext' do
trace[Tactic.Sym] "state? := {state?}"
let lctx ← getLCtx

Expand All @@ -440,55 +442,11 @@ def fromLocalContext (state? : Option Name) : TacticM SymContext := do
c.modify <| do
searchLCtx SymContext.searchFor

let thms ← (← readThe AxEffects).toSimpTheorems
modifyThe SymContext (·.addSimpTheorems thms)

inferStatePrefixAndNumber
where
findLocalDeclOfType? (expectedType : Expr) : MetaM (Option LocalDecl) := do
let msg := m!"Searching for hypothesis of type: {expectedType}"
withTraceNode `Tactic.sym (fun _ => pure msg) <| do
let decl? ← _root_.findLocalDeclOfType? expectedType
trace[Tactic.sym] "Found: {(·.toExpr) <$> decl?}"
return decl?
findLocalDeclOfTypeOrError (expectedType : Expr) : MetaM LocalDecl := do
let msg := m!"Searching for hypothesis of type: {expectedType}"
withTraceNode `Tactic.sym (fun _ => pure msg) <| do
let decl ← _root_.findLocalDeclOfTypeOrError expectedType
trace[Tactic.sym] "Found: {decl.toExpr}"
return decl

/-! ## Massaging the local context -/

/-- change the type (in the local context of the main goal)
of the hypotheses tracked by the given `SymContext` to be *exactly* of the shape
described in the relevant docstrings.
That is, (un)fold types which were definitionally, but not syntactically,
equal to the expected shape. -/
def canonicalizeHypothesisTypes : SymReaderM Unit := withMainContext' do
let c ← readThe SymContext
let lctx ← getLCtx
let mut goal ← getMainGoal
let state := c.effects.currentState

let mut hyps := #[]
if let some runSteps := c.runSteps? then
hyps := hyps.push (c.h_run, h_run_type c.finalState (toExpr runSteps) state)
if let some h_sp := c.h_sp? then
hyps := hyps.push (h_sp, h_sp_type state)

let mut hypIds ← hyps.mapM fun ⟨name, type⟩ => do
let some decl := lctx.findFromUserName? name
| throwError "Unknown local hypothesis `{name}`"
pure (decl.fvarId, type)

let errHyp ← AxEffects.getFieldM .ERR
if let Expr.fvar id := errHyp.proof then
hypIds := hypIds.push (id, h_err_type state)
for ⟨fvarId, type⟩ in hypIds do
goal ← goal.replaceLocalDeclDefEq fvarId type
replaceMainGoal [goal]
withMainContext' <| do
let thms ← (← readThe AxEffects).toSimpTheorems
modifyThe SymContext (·.addSimpTheorems thms)

inferStatePrefixAndNumber

/-! ## Incrementing the context to the next state -/

Expand Down
28 changes: 21 additions & 7 deletions Tactics/Sym/LCtxSearch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ structure LCtxSearchState.Pattern where
(as determined by the return value of `whenFound`)
could be found in the local context -/
whenNotFound : Expr → m Unit
/-- Whether to change the type of successful matches -/
changeType : Bool
/-- How many times have we (successfully) found the pattern -/
occurences : Nat := 0
/-- Whether the pattern is active; is `isActive = false`,
Expand Down Expand Up @@ -93,6 +95,9 @@ variable {m}
def-eq to the pattern, or if `whenFound` returned `skip` for all variables
that were found. For convenience, we pass in the `expectedType` here as well.
See `throwNotFound` for a convenient way to throw an error here.
- If `changeType` is set to true, then we change the type of every successful
match (i.e., `whenFound` returns `continu` or `done`) to be exactly the
`expectedType`
WARNING: Once a pattern is found for which `whenFound` returns `done`, that
particular variable will not be matched for any other patterns.
Expand All @@ -103,30 +108,35 @@ def searchLCtxFor
(expectedType : m Expr)
(whenFound : LocalDecl → Expr → m LCtxSearchResult)
(whenNotFound : Expr → m Unit := fun _ => pure ())
(changeType : Bool := false)
: SearchLCtxForM m Unit := do
let pattern := {
-- Placeholder value, since we can't evaluate `m` inside of `LCtxSearchM`
cachedExpectedType :=← expectedType
expectedType, whenFound, whenNotFound
expectedType, whenFound, whenNotFound, changeType
}
modify fun state => { state with
patterns := state.patterns.push pattern
}

/-- A wrapper around `searchLCtxFor`, which is simplified for matching at most
one occurence of `expectedType` -/
one occurence of `expectedType`.
See `searchLCtxFor` for an explanation of the arguments -/
def searchLCtxForOnce
(expectedType : Expr)
(whenFound : LocalDecl → Expr → m Unit)
(whenNotFound : Expr → m Unit := fun _ => pure ())
(changeType : Bool := false)
: SearchLCtxForM m Unit := do
searchLCtxFor (pure expectedType)
(fun d e => do whenFound d e; return .done)
whenNotFound
whenNotFound changeType

section Run
open Elab.Tactic
open Meta (isDefEq)
variable [MonadLCtx m] [MonadLiftT MetaM m]
variable [MonadLCtx m] [MonadLiftT MetaM m] [MonadLiftT TacticM m]

/--
Attempt to match `e` against the given pattern:
Expand All @@ -144,9 +154,13 @@ def LCtxSearchState.Pattern.match? (pat : Pattern m) (decl : LocalDecl) :
else
let cachedExpectedType ← pat.expectedType
let res ← pat.whenFound decl pat.cachedExpectedType
let occurences := match res with
| .skip => pat.occurences
| .done | .continu => pat.occurences + 1
let mut occurences := pat.occurences
if res != .skip then
occurences := occurences + 1
if pat.changeType = true then
let goal ← getMainGoal
let goal ← goal.replaceLocalDeclDefEq decl.fvarId pat.cachedExpectedType
replaceMainGoal [goal]
return some ({pat with cachedExpectedType, occurences}, res)

/-- Search the local context for variables of certain types, in a single pass.
Expand Down

0 comments on commit 20173f9

Please sign in to comment.