Skip to content

Commit

Permalink
refactor: incorporate an AxEffects field in SymContext, introduce `Sy…
Browse files Browse the repository at this point in the history
…mM` monad for `SymContext` state (#180)

### Description:

This turned into a rather big tech-debt removal PR. The primary focus is
the removal of the duplication we had between `SymContext` tracking
names of hypotheses, and `AxEffects` tracking `Expr`s for those same
hypotheses.

- Added an `effects : AxEffects` field to `SymContext`, which stores the
`AxEffects` for a single (non-aggregated!) step
- To make live easier, we introduce a `SymM` monad, so that we don't
have to project out to the effect field every single every time
(credits/blame go to @bollu for showing me the `MonadStateOf` trick).
- This allowed us to remove a bunch of fields of `SymContext` which had
duplicates in `AxEffects`
- Also, it allowed us to move reflection of the PC out of `explodeSteps`
and into `prepareForNextStep` (which was previously called
`SymContext.next`),
- Finally, we extract `ensureAtMostRunSteps` and
`assertStepTheoremsGenerated` functions out of the main body of `sym_n`

### Testing:

What tests have been run? Did `make all` succeed for your changes? Was
conformance testing successful on an Aarch64 machine? yes

### License:

By submitting this pull request, I confirm that my contribution is
made under the terms of the Apache 2.0 license.

---------

Co-authored-by: Shilpi Goel <[email protected]>
  • Loading branch information
alexkeizer and shigoel authored Sep 25, 2024
1 parent 27a2f3b commit 26b5da0
Show file tree
Hide file tree
Showing 4 changed files with 385 additions and 232 deletions.
13 changes: 11 additions & 2 deletions Tactics/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ def findLocalDeclOfType? (expectedType : Expr) : MetaM (Option LocalDecl) := do
-- the local context, so we can safely pass it to `get!`

def findLocalDeclOfTypeOrError (expectedType : Expr) : MetaM LocalDecl := do
let some name ← findLocalDeclOfType? expectedType
let some decl ← findLocalDeclOfType? expectedType
| throwError "Failed to find a local hypothesis of type {expectedType}"
return name
return decl

/-- `findProgramHyp` searches the local context for an hypothesis of type
`state.program = ?concreteProgram`,
Expand Down Expand Up @@ -269,3 +269,12 @@ def traceHeartbeats (cls : Name) (header : Option String := none) :
let percent ← heartbeatsPercent
trace cls fun _ =>
m!"{header}used {heartbeats} heartbeats ({percent}% of maximum)"

/-! ## `withMainContext'` -/

variable {m} [Monad m] [MonadLiftT TacticM m] [MonadControlT MetaM m] in
/-- Execute `x` using the main goal local context and instances.
Unlike the standard `withMainContext`, `x` may live in a generic monad `m`. -/
def withMainContext' (x : m α) : m α := do
(← getMainGoal).withContext x
61 changes: 59 additions & 2 deletions Tactics/Reflect/AxEffects.lean
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,37 @@ structure AxEffects where

namespace AxEffects

/-! ## Monad getters -/

section Monad
variable {m} [Monad m] [MonadReaderOf AxEffects m]

def getCurrentState : m Expr := do return (← read).currentState
def getInitialState : m Expr := do return (← read).initialState
def getNonEffectProof : m Expr := do return (← read).nonEffectProof
def getMemoryEffect : m Expr := do return (← read).memoryEffect
def getMemoryEffectProof : m Expr := do return (← read).memoryEffectProof
def getProgramProof : m Expr := do return (← read).programProof

def getStackAlignmentProof? : m (Option Expr) := do
return (← read).stackAlignmentProof?

variable [MonadLiftT MetaM m] in
/-- Retrieve the user-facing name of the current state, assuming that
the current state is a free variable in the ambient local context -/
def getCurrentStateName : m Name := do
let state ← getCurrentState
@id (MetaM _) <| do
let state ← instantiateMVars state
let Expr.fvar id := state.consumeMData
| throwError "error: expected a free variable, found:\n {state} WHHOPS"
let lctx ← getLCtx
let some decl := lctx.find? id
| throwError "error: unknown fvar: {state}"
return decl.userName

end Monad

/-! ## Initial Reflected State -/

/-- An initial `AxEffects` state which has no writes.
Expand Down Expand Up @@ -185,7 +216,7 @@ partial def mkAppNonEffect (eff : AxEffects) (field : Expr) : MetaM Expr := do
return nonEffectProof

/-- Get the value for a field, if one is stored in `eff.fields`,
or assemble an instantiation of the non-effects proof -/
or assemble an instantiation of the non-effects proof otherwise -/
def getField (eff : AxEffects) (fld : StateField) : MetaM FieldEffect :=
let msg := m!"getField {fld}"
withTraceNode `Tactic.sym (fun _ => pure msg) <| do
Expand All @@ -200,6 +231,11 @@ def getField (eff : AxEffects) (fld : StateField) : MetaM FieldEffect :=
let proof ← eff.mkAppNonEffect (toExpr fld)
pure { value, proof }

variable {m} [Monad m] [MonadReaderOf AxEffects m] [MonadLiftT MetaM m] in
@[inherit_doc getField]
def getFieldM (field : StateField) : m FieldEffect := do
(← read).getField field

/-! ## Update a Reflected State -/

/-- Execute `write_mem <n> <addr> <val>` against the state stored in `eff`
Expand Down Expand Up @@ -359,6 +395,24 @@ private def assertIsDefEq (e expected : Expr) : MetaM Unit := do
if !(←isDefEq e expected) then
throwError "expected:\n {expected}\nbut found:\n {e}"

/-- Given an expression `e : ArmState`,
which is a sequence of `w`/`write_mem`s to `eff.currentState`,
return an `AxEffects` where `e` is the new `currentState`. -/
partial def updateWithExpr (eff : AxEffects) (e : Expr) : MetaM AxEffects := do
let msg := m!"Updating effects with writes from: {e}"
withTraceNode `Tactic.sym (fun _ => pure msg) <| do match_expr e with
| write_mem_bytes n addr val e =>
let eff ← eff.updateWithExpr e
eff.update_write_mem n addr val

| w field value e =>
let eff ← eff.updateWithExpr e
eff.update_w field value

| _ =>
assertIsDefEq e eff.currentState
return eff

/-- Given an expression `e : ArmState`,
which is a sequence of `w`/`write_mem`s to the some state `s`,
return an `AxEffects` where `s` is the intial state, and `e` is `currentState`.
Expand Down Expand Up @@ -466,7 +520,10 @@ def withField (eff : AxEffects) (eq : Expr) : MetaM AxEffects := do
trace[Tactic.sym] "current field effect: {fieldEff}"

if field ∉ eff.fields then
let proof ← mkEqTrans fieldEff.proof eq
let proof ← if eff.currentState == eff.initialState then
pure eq
else
mkEqTrans fieldEff.proof eq
let fields := eff.fields.insert field { value, proof }
return { eff with fields }
else
Expand Down
Loading

0 comments on commit 26b5da0

Please sign in to comment.