Skip to content

Commit

Permalink
refactor: use searchLCtx in sym_aggregate (#201)
Browse files Browse the repository at this point in the history
### Description:

Stacked on #200.

This uses the `searchLCtx` machinery we built in #189 in
`sym_aggregate`'s implementation.
For now, this searches for exactly the same kind of expressions as
before, but this makes it much easier to expand this set in the next PR.

### Testing:

What tests have been run? Did `make all` succeed for your changes? Was
conformance testing successful on an Aarch64 machine? yes

### License:

By submitting this pull request, I confirm that my contribution is
made under the terms of the Apache 2.0 license.

---------

Co-authored-by: Shilpi Goel <[email protected]>
  • Loading branch information
alexkeizer and shigoel authored Oct 3, 2024
1 parent 9431fcb commit fbae789
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 44 deletions.
77 changes: 40 additions & 37 deletions Tactics/Aggregate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Author(s): Alex Keizer, Siddharth Bhat
import Lean
import Tactics.Common
import Tactics.Simp
import Tactics.Sym.LCtxSearch

open Lean Meta Elab.Tactic

Expand Down Expand Up @@ -65,44 +66,46 @@ elab "sym_aggregate" simpConfig?:(config)? loc?:(location)? : tactic => withMain
let simpConfig? ← simpConfig?.mapM fun cfg =>
elabSimpConfig (mkNullNode #[cfg]) (kind := .simp)

let lctx ← getLCtx
-- We keep `expectedRead`/`expectedAlign` as monadic values,
-- so that we get new metavariables for each localdecl we check
let expectedRead : MetaM Expr := do
let fld ← mkFreshExprMVar (mkConst ``StateField)
let state ← mkFreshExprMVar mkArmState
let rhs ← mkFreshExprMVar none
mkEq (mkApp2 (mkConst ``r) fld state) rhs
let expectedReadMem : MetaM Expr := do
let n ← mkFreshExprMVar (mkConst ``Nat)
let addr ← mkFreshExprMVar (mkApp (mkConst ``BitVec) (toExpr 64))
let mem ← mkFreshExprMVar (mkConst ``Memory)
let rhs ← mkFreshExprMVar none
mkEq (mkApp3 (mkConst ``Memory.read_bytes) n addr mem) rhs
let expectedAlign : MetaM Expr := do
let state ← mkFreshExprMVar mkArmState
return mkApp (mkConst ``CheckSPAlignment) state
/-
We construct `axHyps` by running a `State` monad, which is
initialized with an empty array
-/
let ((), axHyps) ← StateT.run (s := #[]) <|
searchLCtx <| do
let whenFound := fun decl _ => do
-- Whenever a match is found, we add the corresponding declaration
-- to the `axHyps` array in the monadic state
modify (·.push decl)
return .continu

let axHyps ←
withTraceNode `Tactic.sym (fun _ => pure m!"searching for effect hypotheses") <|
lctx.foldlM (init := #[]) fun axHyps decl => do
forallTelescope decl.type <| fun _ type => do
trace[Tactic.sym] "checking {decl.toExpr} with type {type}"
let expectedRead ← expectedRead
let expectedAlign ← expectedAlign
let expectedReadMem ← expectedReadMem
if ← isDefEq type expectedRead then
trace[Tactic.sym] "{Lean.checkEmoji} match for {expectedRead}"
return axHyps.push decl
else if ← isDefEq type expectedAlign then
trace[Tactic.sym] "{Lean.checkEmoji} match for {expectedAlign}"
return axHyps.push decl
else if ← isDefEq type expectedReadMem then
trace[Tactic.sym] "{Lean.checkEmoji} match for {expectedReadMem}"
return axHyps.push decl
else
trace[Tactic.sym] "{Lean.crossEmoji} no match"
return axHyps
-- `r ?field ?state = ?rhs`
searchLCtxFor (whenFound := whenFound)
/- By matching under binders, this also matches for non-effect
hypotheses, which look like:
`∀ f, f ≠ _ → r f ?state = ?rhs`
-/
(matchUnderBinders := true)
(expectedType := do
let fld ← mkFreshExprMVar (mkConst ``StateField)
let state ← mkFreshExprMVar mkArmState
let rhs ← mkFreshExprMVar none
return mkEqReadField fld state rhs
)
-- `Memory.read_bytes ?n ?addr ?mem = ?rhs`
searchLCtxFor (whenFound := whenFound)
(matchUnderBinders := true)
(expectedType := do
let n ← mkFreshExprMVar (mkConst ``Nat)
let addr ← mkFreshExprMVar (mkApp (mkConst ``BitVec) (toExpr 64))
let mem ← mkFreshExprMVar (mkConst ``Memory)
let rhs ← mkFreshExprMVar none
mkEq (mkApp3 (mkConst ``Memory.read_bytes) n addr mem) rhs
)
-- `CheckSpAlignment ?state`
searchLCtxFor (whenFound := whenFound)
(expectedType := do
let state ← mkFreshExprMVar mkArmState
return mkApp (mkConst ``CheckSPAlignment) state)

let loc := (loc?.map expandLocation).getD (.targets #[] true)
aggregate axHyps loc simpConfig?
35 changes: 28 additions & 7 deletions Tactics/Sym/LCtxSearch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ structure LCtxSearchState.Pattern where
whenNotFound : Expr → m Unit
/-- Whether to change the type of successful matches -/
changeType : Bool
/-- Whether to match under binders -/
matchUnderBinders : Bool
/-- How many times have we (successfully) found the pattern -/
occurences : Nat := 0
/-- Whether the pattern is active; is `isActive = false`,
Expand Down Expand Up @@ -95,9 +97,14 @@ variable {m}
def-eq to the pattern, or if `whenFound` returned `skip` for all variables
that were found. For convenience, we pass in the `expectedType` here as well.
See `throwNotFound` for a convenient way to throw an error here.
- If `changeType` is set to true, then we change the type of every successful
match (i.e., `whenFound` returns `continu` or `done`) to be exactly the
`expectedType`
- If `changeType` (default `false`) is set to true, then we change the type of
every successful match (i.e., `whenFound` returns `continu` or `done`)
to be exactly the `expectedType`
- If `matchUnderBinders` (default `false`) is set to true, we will introduce
metavariable for any binders in a variable's type before matching.
For example, with `matchUnderBinders` set to true, we consider a variable
`h : ∀ f, r f s0 = r f s1`
as a match for expected type `r ?f s0 = r ?f s1`.
WARNING: Once a pattern is found for which `whenFound` returns `done`, that
particular variable will not be matched for any other patterns.
Expand All @@ -109,11 +116,12 @@ def searchLCtxFor
(whenFound : LocalDecl → Expr → m LCtxSearchResult)
(whenNotFound : Expr → m Unit := fun _ => pure ())
(changeType : Bool := false)
(matchUnderBinders : Bool := false)
: SearchLCtxForM m Unit := do
let pattern := {
-- Placeholder value, since we can't evaluate `m` inside of `LCtxSearchM`
cachedExpectedType :=← expectedType
expectedType, whenFound, whenNotFound, changeType
expectedType, whenFound, whenNotFound, changeType, matchUnderBinders
}
modify fun state => { state with
patterns := state.patterns.push pattern
Expand All @@ -128,28 +136,39 @@ def searchLCtxForOnce
(whenFound : LocalDecl → Expr → m Unit)
(whenNotFound : Expr → m Unit := fun _ => pure ())
(changeType : Bool := false)
(matchUnderBinders : Bool := false)
: SearchLCtxForM m Unit := do
searchLCtxFor (pure expectedType)
(fun d e => do whenFound d e; return .done)
whenNotFound changeType
whenNotFound changeType matchUnderBinders

section Run
open Elab.Tactic
open Meta (isDefEq)
variable [MonadLCtx m] [MonadLiftT MetaM m] [MonadLiftT TacticM m]

namespace LCtxSearchState

/-- Return `true` if `e` matches the pattern -/
def Pattern.matches (pat : Pattern m) (e : Expr) : m Bool := do
let mut e := e
if pat.matchUnderBinders then
let ⟨_, _, e'⟩ ← Meta.forallMetaTelescope e
e := e'
isDefEq e pat.cachedExpectedType

/--
Attempt to match `e` against the given pattern:
- if `e` is def-eq to `pat.cachedExpectedType`, then return
the updated pattern state (with a fresh `cachedExpectedType`), and
the result of `whenFound`
- Otherwise, if `e` is not def-eq, return `none`
-/
def LCtxSearchState.Pattern.match? (pat : Pattern m) (decl : LocalDecl) :
def Pattern.match? (pat : Pattern m) (decl : LocalDecl) :
m (Option (Pattern m × LCtxSearchResult)) := do
if !pat.isActive then
return none
else if !(← isDefEq decl.type pat.cachedExpectedType) then
else if !(← pat.matches decl.type) then
return none
else
let cachedExpectedType ← pat.expectedType
Expand All @@ -163,6 +182,8 @@ def LCtxSearchState.Pattern.match? (pat : Pattern m) (decl : LocalDecl) :
replaceMainGoal [goal]
return some ({pat with cachedExpectedType, occurences}, res)

end LCtxSearchState

/-- Search the local context for variables of certain types, in a single pass.
`k` is a monadic continuation that determines the patterns to search for,
see `searchLCtxFor` to see how to register those patterns
Expand Down

0 comments on commit fbae789

Please sign in to comment.