diff --git a/.github/workflows/copyright-header.yml b/.github/workflows/copyright-header.yml new file mode 100644 index 00000000..1863a250 --- /dev/null +++ b/.github/workflows/copyright-header.yml @@ -0,0 +1,20 @@ +name: Check for copyright header + +on: [pull_request] + +jobs: + check-lean-files: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Verify .lean files start with a copyright header. + run: | + FILES=$(find . -type d \( -path "./.lake" \) -prune -o -type f -name "*.lean" -exec perl -ne 'BEGIN { $/ = undef; } print "$ARGV\n" if !m{\A/-\nCopyright}; exit;' {} \;) + if [ -n "$FILES" ]; then + echo "Found .lean files which do not have a copyright header:" + echo "$FILES" + exit 1 + else + echo "All copyright headers present." + fi \ No newline at end of file diff --git a/Arm/Attr.lean b/Arm/Attr.lean index fe2c434f..d9076b78 100644 --- a/Arm/Attr.lean +++ b/Arm/Attr.lean @@ -1,3 +1,9 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Shilpi Goel +-/ + import Lean -- A minimal theory, safe for all LNSym proofs @@ -9,21 +15,3 @@ register_simp_attr state_simp_rules register_simp_attr bitvec_rules -- Rules for memory lemmas register_simp_attr memory_rules - -/- -syntax "state_simp" : tactic -macro_rules - | `(tactic| state_simp) => `(tactic| simp only [state_simp_rules]) - -syntax "state_simp?" : tactic -macro_rules - | `(tactic| state_simp?) => `(tactic| simp? only [state_simp_rules]) - -syntax "state_simp_all" : tactic -macro_rules - | `(tactic| state_simp_all) => `(tactic| simp_all only [state_simp_rules]) - -syntax "state_simp_all?" : tactic -macro_rules - | `(tactic| state_simp_all?) => `(tactic| simp_all? only [state_simp_rules]) --/ diff --git a/Arm/FromMathlib.lean b/Arm/FromMathlib.lean index 352537cf..3d17145f 100644 --- a/Arm/FromMathlib.lean +++ b/Arm/FromMathlib.lean @@ -1,3 +1,9 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +-/ + + -- This file has definitions temporarily lifted from Mathlib. -- We will move them into Lean shortly. diff --git a/Arm/MinTheory.lean b/Arm/MinTheory.lean index 602f47ee..a4cbc4fd 100644 --- a/Arm/MinTheory.lean +++ b/Arm/MinTheory.lean @@ -1,3 +1,9 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Shilpi Goel +-/ + import Arm.Attr -- These lemmas are from lean/Init/SimpLemmas.lean. diff --git a/Proofs/SHA512/SHA512.lean b/Proofs/SHA512/SHA512.lean index 03709f50..f044c44e 100644 --- a/Proofs/SHA512/SHA512.lean +++ b/Proofs/SHA512/SHA512.lean @@ -1,2 +1,7 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Shilpi Goel +-/ import Proofs.SHA512.SHA512_block_armv8_rules import Proofs.SHA512.SHA512Sym diff --git a/Tactics/Common.lean b/Tactics/Common.lean index 6bd98f26..ac1432ee 100644 --- a/Tactics/Common.lean +++ b/Tactics/Common.lean @@ -223,9 +223,9 @@ def findLocalDeclOfType? (expectedType : Expr) : MetaM (Option LocalDecl) := do -- the local context, so we can safely pass it to `get!` def findLocalDeclOfTypeOrError (expectedType : Expr) : MetaM LocalDecl := do - let some name ← findLocalDeclOfType? expectedType + let some decl ← findLocalDeclOfType? expectedType | throwError "Failed to find a local hypothesis of type {expectedType}" - return name + return decl /-- `findProgramHyp` searches the local context for an hypothesis of type `state.program = ?concreteProgram`, @@ -269,3 +269,12 @@ def traceHeartbeats (cls : Name) (header : Option String := none) : let percent ← heartbeatsPercent trace cls fun _ => m!"{header}used {heartbeats} heartbeats ({percent}% of maximum)" + +/-! ## `withMainContext'` -/ + +variable {m} [Monad m] [MonadLiftT TacticM m] [MonadControlT MetaM m] in +/-- Execute `x` using the main goal local context and instances. + +Unlike the standard `withMainContext`, `x` may live in a generic monad `m`. -/ +def withMainContext' (x : m α) : m α := do + (← getMainGoal).withContext x diff --git a/Tactics/Reflect/AxEffects.lean b/Tactics/Reflect/AxEffects.lean index 442e0b74..69cb5238 100644 --- a/Tactics/Reflect/AxEffects.lean +++ b/Tactics/Reflect/AxEffects.lean @@ -82,6 +82,37 @@ structure AxEffects where namespace AxEffects +/-! ## Monad getters -/ + +section Monad +variable {m} [Monad m] [MonadReaderOf AxEffects m] + +def getCurrentState : m Expr := do return (← read).currentState +def getInitialState : m Expr := do return (← read).initialState +def getNonEffectProof : m Expr := do return (← read).nonEffectProof +def getMemoryEffect : m Expr := do return (← read).memoryEffect +def getMemoryEffectProof : m Expr := do return (← read).memoryEffectProof +def getProgramProof : m Expr := do return (← read).programProof + +def getStackAlignmentProof? : m (Option Expr) := do + return (← read).stackAlignmentProof? + +variable [MonadLiftT MetaM m] in +/-- Retrieve the user-facing name of the current state, assuming that +the current state is a free variable in the ambient local context -/ +def getCurrentStateName : m Name := do + let state ← getCurrentState + @id (MetaM _) <| do + let state ← instantiateMVars state + let Expr.fvar id := state.consumeMData + | throwError "error: expected a free variable, found:\n {state} WHHOPS" + let lctx ← getLCtx + let some decl := lctx.find? id + | throwError "error: unknown fvar: {state}" + return decl.userName + +end Monad + /-! ## Initial Reflected State -/ /-- An initial `AxEffects` state which has no writes. @@ -185,7 +216,7 @@ partial def mkAppNonEffect (eff : AxEffects) (field : Expr) : MetaM Expr := do return nonEffectProof /-- Get the value for a field, if one is stored in `eff.fields`, -or assemble an instantiation of the non-effects proof -/ +or assemble an instantiation of the non-effects proof otherwise -/ def getField (eff : AxEffects) (fld : StateField) : MetaM FieldEffect := let msg := m!"getField {fld}" withTraceNode `Tactic.sym (fun _ => pure msg) <| do @@ -200,6 +231,11 @@ def getField (eff : AxEffects) (fld : StateField) : MetaM FieldEffect := let proof ← eff.mkAppNonEffect (toExpr fld) pure { value, proof } +variable {m} [Monad m] [MonadReaderOf AxEffects m] [MonadLiftT MetaM m] in +@[inherit_doc getField] +def getFieldM (field : StateField) : m FieldEffect := do + (← read).getField field + /-! ## Update a Reflected State -/ /-- Execute `write_mem ` against the state stored in `eff` @@ -359,6 +395,24 @@ private def assertIsDefEq (e expected : Expr) : MetaM Unit := do if !(←isDefEq e expected) then throwError "expected:\n {expected}\nbut found:\n {e}" +/-- Given an expression `e : ArmState`, +which is a sequence of `w`/`write_mem`s to `eff.currentState`, +return an `AxEffects` where `e` is the new `currentState`. -/ +partial def updateWithExpr (eff : AxEffects) (e : Expr) : MetaM AxEffects := do + let msg := m!"Updating effects with writes from: {e}" + withTraceNode `Tactic.sym (fun _ => pure msg) <| do match_expr e with + | write_mem_bytes n addr val e => + let eff ← eff.updateWithExpr e + eff.update_write_mem n addr val + + | w field value e => + let eff ← eff.updateWithExpr e + eff.update_w field value + + | _ => + assertIsDefEq e eff.currentState + return eff + /-- Given an expression `e : ArmState`, which is a sequence of `w`/`write_mem`s to the some state `s`, return an `AxEffects` where `s` is the intial state, and `e` is `currentState`. @@ -466,7 +520,10 @@ def withField (eff : AxEffects) (eq : Expr) : MetaM AxEffects := do trace[Tactic.sym] "current field effect: {fieldEff}" if field ∉ eff.fields then - let proof ← mkEqTrans fieldEff.proof eq + let proof ← if eff.currentState == eff.initialState then + pure eq + else + mkEqTrans fieldEff.proof eq let fields := eff.fields.insert field { value, proof } return { eff with fields } else diff --git a/Tactics/Sym.lean b/Tactics/Sym.lean index 8a7579bb..1b74546a 100644 --- a/Tactics/Sym.lean +++ b/Tactics/Sym.lean @@ -14,6 +14,8 @@ open BitVec open Lean Meta open Lean.Elab.Tactic +open AxEffects SymContext + /-- A wrapper around `evalTactic` that traces the passed tactic script, executes those tactics, and then traces the new goal state -/ private def evalTacticAndTrace (tactic : TSyntax `tactic) : TacticM Unit := @@ -42,25 +44,30 @@ macro "init_next_step" h_run:ident stepi_eq:ident sn:ident : tactic => section stepiTac -/-- Apply the relevant pre-generated stepi lemma to a local hypothesis +/-- Apply the relevant pre-generated stepi lemma to an expression `stepi_eq : stepi ?s = ?s'` -to obtain a new local hypothesis in terms of `w` and `write_mem` +to add a new local hypothesis in terms of `w` and `write_mem` `h_step : ?s' = w _ _ (w _ _ (... ?s))` -/ -def stepiTac (stepi_eq h_step : Ident) (ctx : SymContext) - : TacticM Unit := withMainContext do - let pc := (Nat.toDigits 16 ctx.pc.toNat).asString - -- ^^ The PC in hex - let step_lemma := mkIdent <| Name.str ctx.program s!"stepi_eq_0x{pc}" - - evalTacticAndTrace <|← `(tactic| ( - have $h_step := - _root_.Eq.trans (Eq.symm $stepi_eq) - ($step_lemma:ident - $ctx.h_program_ident:ident - $ctx.h_pc_ident:ident - $ctx.h_err_ident:ident) - )) +def stepiTac (stepiEq : Expr) (hStep : Name) : SymReaderM Unit := fun ctx => + withMainContext' do + let pc := (Nat.toDigits 16 ctx.pc.toNat).asString + -- ^^ The PC in hex + let stepLemma := Name.str ctx.program s!"stepi_eq_0x{pc}" + -- let stepLemma := Expr.const stepLemma [] + + let eff := ctx.effects + let hStepExpr ← mkEqTrans + (← mkEqSymm stepiEq) + (← mkAppM stepLemma #[ + eff.programProof, + (← eff.getField .PC).proof, + (← eff.getField .ERR).proof + ]) + + let goal ← getMainGoal + let ⟨_, goal⟩ ← goal.note hStep hStepExpr + replaceMainGoal [goal] end stepiTac @@ -80,8 +87,8 @@ for some metavariable `?runSteps`, then create the proof obligation `?runSteps = _ + 1`, and attempt to close it using `whileTac`. Finally, we use this proof to change the type of `h_run` accordingly. -/ -def unfoldRun (c : SymContext) (whileTac : Unit → TacticM Unit) : - TacticM Unit := +def unfoldRun (whileTac : Unit → TacticM Unit) : SymReaderM Unit := do + let c ← readThe SymContext let msg := m!"unfoldRun (runSteps? := {c.runSteps?})" withTraceNode `Tactic.sym (fun _ => pure msg) <| match c.runSteps? with @@ -94,7 +101,7 @@ def unfoldRun (c : SymContext) (whileTac : Unit → TacticM Unit) : -- NOTE: this error shouldn't occur, as we should have checked in -- `sym_n` that, if the number of runSteps is statically known, -- that we never simulate more than that many steps - | none => withMainContext do + | none => withMainContext' do let mut goal :: originalGoals ← getGoals | throwNoGoalsToBeSolved let hRunDecl ← c.hRunDecl @@ -104,7 +111,7 @@ def unfoldRun (c : SymContext) (whileTac : Unit → TacticM Unit) : guard <|← isDefEq hRunDecl.type ( mkApp3 (.const ``Eq [1]) (mkConst ``ArmState) c.finalState - (mkApp2 (mkConst ``_root_.run) runSteps (←c.stateExpr))) + (mkApp2 (mkConst ``_root_.run) runSteps (← getCurrentState))) -- NOTE: ^^ Since we check for def-eq on SymContext construction, -- this check should never fail @@ -132,8 +139,9 @@ def unfoldRun (c : SymContext) (whileTac : Unit → TacticM Unit) : runStepsPredId.assign default -- Change the type of `h_run` + let state ← getCurrentState let typeNew ← do - let rhs := mkApp2 (mkConst ``_root_.run) subGoalTyRhs (←c.stateExpr) + let rhs := mkApp2 (mkConst ``_root_.run) subGoalTyRhs state mkEq c.finalState rhs let eqProof ← do let f := -- `fun s => = s` @@ -141,9 +149,8 @@ def unfoldRun (c : SymContext) (whileTac : Unit → TacticM Unit) : c.finalState (.bvar 0) .lam `s (mkConst ``ArmState) eq .default let g := mkConst ``_root_.run - let s ← c.stateExpr let h ← instantiateMVars (.mvar subGoal) - mkCongrArg f (←mkCongrFun (←mkCongrArg g h) s) + mkCongrArg f (←mkCongrFun (←mkCongrArg g h) state) let res ← goal.replaceLocalDecl hRunDecl.fvarId typeNew eqProof -- Restore goal state @@ -153,18 +160,17 @@ def unfoldRun (c : SymContext) (whileTac : Unit → TacticM Unit) : originalGoals := originalGoals.concat subGoal setGoals (res.mvarId :: originalGoals) -/-- Given an equality `h_step : s{i+1} = w ... (... (w ... s{i})...)`, -add hypotheses that axiomatically describe the effects in terms of -reads from `s{i+1}`. - -Return the context for the next step (see `SymContext.next`), where -we attempt to determine the new PC by reflecting the obtained effects, -falling back to incrementing the PC if reflection failed. -/ -def explodeStep (c : SymContext) (hStep : Expr) : TacticM SymContext := - withMainContext do +/-- Break an equality `h_step : s{i+1} = w ... (... (w ... s{i})...)` into an +`AxEffects` that characterizes the effects in terms of reads from `s{i+1}`, +add the relevant hypotheses to the local context, and +store an `AxEffects` object with the newly added variables in the monad state +-/ +def explodeStep (hStep : Expr) : SymM Unit := + withMainContext' do + let c ← getThe SymContext let mut eff ← AxEffects.fromEq hStep - let stateExpr ← c.stateExpr + let stateExpr ← getCurrentState /- Assert that the initial state of the obtained `AxEffects` is equal to the state tracked by `c`. This will catch and throw an error if the semantics of the current @@ -173,11 +179,8 @@ def explodeStep (c : SymContext) (hStep : Expr) : TacticM SymContext := throwError "[explodeStep] expected initial state {stateExpr}, but found:\n \ {eff.initialState}\nin\n\n{eff}" - let hProgram ← SymContext.findFromUserName c.h_program - eff ← eff.withProgramEq hProgram.toExpr - - let hErr ← SymContext.findFromUserName c.h_err - eff ← eff.withField hErr.toExpr + eff ← eff.withProgramEq c.effects.programProof + eff ← eff.withField (← c.effects.getField .ERR).proof if let some h_sp := c.h_sp? then let hSp ← SymContext.findFromUserName h_sp @@ -198,12 +201,11 @@ def explodeStep (c : SymContext) (hStep : Expr) : TacticM SymContext := let subGoal ← mkFreshMVarId -- subGoal.setTag <| let hAligned ← do - let name := Name.mkSimple s!"h_{c.next_state}_sp_aligned" + let name := Name.mkSimple s!"h_{← getNextStateName}_sp_aligned" mkFreshExprMVarWithId subGoal (userName := name) <| mkAppN (mkConst ``Aligned) #[toExpr 64, spEff.value, toExpr 4] trace[Tactic.sym] "created subgoal to show alignment:\n{subGoal}" - let subGoal? ← do let (ctx, simprocs) ← LNSymSimpContext @@ -222,30 +224,16 @@ def explodeStep (c : SymContext) (hStep : Expr) : TacticM SymContext := #[eff.currentState, spEff.value, spEff.proof, hAligned] pure { eff with stackAlignmentProof? } - -- Add new (non-)effect hyps to the context - let simpThms ← withMainContext <| do + -- Add new (non-)effect hyps to the context, and to the aggregation simpset + withMainContext' <| do if ←(getBoolOption `Tactic.sym.debug) then eff.validate - let eff ← eff.addHypothesesToLContext s!"h_{c.next_state}_" - withMainContext <| eff.toSimpTheorems - - -- Add the new (non-)effect hyps to the aggregation simp context - let c := c.addSimpTheorems simpThms - - -- Attempt to reflect the new PC - let nextPc ← eff.getField .PC - let nextPc? ← try - let nextPc ← reflectBitVecLiteral 64 nextPc.value - -- NOTE: `reflectBitVecLiteral` is fast when the value is a literal, - -- but might involve an expensive reduction when it is not - pure <| some nextPc - catch err => - trace[Tactic.sym] "failed to reflect {nextPc.value}\n\n\ - {err.toMessageData}" - pure none - - return c.next nextPc? + let eff ← eff.addHypothesesToLContext s!"h_{← getNextStateName}_" + withMainContext' <| do + let simpThms ← eff.toSimpTheorems + modifyThe SymContext (·.addSimpTheorems simpThms) + set eff /-- A tactic wrapper around `explodeStep`. Note the use of `SymContext.fromLocalContext`, @@ -257,39 +245,43 @@ elab "explode_step" h_step:term " at " state:term : tactic => withMainContext do | throwError "Expected fvar, found {state}" let stateDecl := (← getLCtx).get! stateFVar let c ← SymContext.fromLocalContext (some stateDecl.userName) - - let _ ← explodeStep c hStep - + let _ ← SymM.run c <| explodeStep hStep /-- Symbolically simulate a single step, according the the symbolic simulation context `c`, returning the context for the next step in simulation. -/ -def sym1 (c : SymContext) (whileTac : TSyntax `tactic) : TacticM SymContext := - let msg := m!"(sym1): simulating step {c.curr_state_number}" - withTraceNode `Tactic.sym (fun _ => pure msg) <| withMainContext do +def sym1 (whileTac : TSyntax `tactic) : SymM Unit := do + let stateNumber ← getCurrentStateNumber + let msg := m!"(sym1): simulating step {stateNumber}" + withTraceNode `Tactic.sym (fun _ => pure msg) <| withMainContext' do withTraceNode `Tactic.sym (fun _ => pure "verbose context") <| do - trace[Tactic.sym] "SymContext:\n{← c.toMessageData}" + traceSymContext trace[Tactic.sym] "Goal state:\n {← getMainGoal}" - let stepi_eq := Lean.mkIdent (.mkSimple s!"stepi_{c.state}") - let h_step := Lean.mkIdent (.mkSimple s!"h_step_{c.curr_state_number + 1}") + let stepi_eq := Lean.mkIdent (.mkSimple s!"stepi_{← getCurrentStateName}") + let h_step := Lean.mkIdent (.mkSimple s!"h_step_{stateNumber + 1}") - unfoldRun c (fun _ => evalTacticAndTrace whileTac) + unfoldRun (fun _ => evalTacticAndTrace whileTac) -- Add new state to local context + let hRunId := mkIdent <|← getHRunName + let nextStateId := mkIdent <|← getNextStateName evalTacticAndTrace <|← `(tactic| - init_next_step $c.h_run_ident:ident $stepi_eq:ident $c.next_state_ident:ident + init_next_step $hRunId:ident $stepi_eq:ident $nextStateId:ident ) -- Apply relevant pre-generated `stepi` lemma - stepiTac stepi_eq h_step c + withMainContext' <| do + let stepiEq ← SymContext.findFromUserName stepi_eq.getId + stepiTac stepiEq.toExpr h_step.getId -- WORKAROUND: eventually we'd like to eagerly simp away `if`s in the -- pre-generation of instruction semantics. For now, though, we keep a -- `simp` here - withMainContext <| do + withMainContext' <| do let hStep ← SymContext.findFromUserName h_step.getId let lctx ← getLCtx - let decls := (c.h_sp?.bind lctx.findFromUserName?).toArray + let decls := (← getThe SymContext).h_sp?.bind lctx.findFromUserName? + let decls := decls.toArray -- If we know SP is aligned, `simp` with that fact if !decls.isEmpty then @@ -309,18 +301,61 @@ def sym1 (c : SymContext) (whileTac : TSyntax `tactic) : TacticM SymContext := skipping simplification step" -- Prepare `h_program`,`h_err`,`h_pc`, etc. for next state - withMainContext <| do + withMainContext' <| do let hStep ← SymContext.findFromUserName h_step.getId -- ^^ we can't reuse `hStep` from before, since its fvarId might've been -- changed by `simp` - let c ← explodeStep c hStep.toExpr + explodeStep hStep.toExpr + prepareForNextStep let goal ← getMainGoal let goal ← goal.clear hStep.fvarId replaceMainGoal [goal] traceHeartbeats - return c + +/-- `ensureLessThanRunSteps n` will +- log a warning and return `m`, if `runSteps? = some m` and `m < n`, or +- return `n` unchanged, otherwise -/ +def ensureAtMostRunSteps (n : Nat) : SymM Nat := do + let ctx ← getThe SymContext + match ctx.runSteps? with + | none => pure n + | some runSteps => + if n ≤ runSteps then + pure n + else + withMainContext <| do + let hRun ← ctx.hRunDecl + logWarning m!"Symbolic simulation is limited to at most {runSteps} \ + steps, because {hRun.toExpr} is of type:\n {hRun.type}" + pure runSteps + return n + +/-- Check that the step-thoerem corresponding to the current PC value exists, +and throw a user-friendly error, pointing to `#genStepEqTheorems`, +if it does not. -/ +def assertStepTheoremsGenerated : SymM Unit := do + let c ← getThe SymContext + let pc := c.pc.toHexWithoutLeadingZeroes + if !c.programInfo.instructions.contains c.pc then + let pcEff ← AxEffects.getFieldM .PC + throwError "\ + Program {c.program} has no instruction at address {c.pc}. + + We inferred this address as the program-counter from {pcEff.proof}, \ + which has type: + {← inferType pcEff.proof}" + + let step_thm := Name.str c.program ("stepi_eq_0x" ++ pc) + try + let _ ← getConstInfo step_thm + catch err => + throwErrorAt err.getRef "{err.toMessageData}\n +Did you remember to generate step theorems with: + #genStepEqTheorems {c.program}" +-- TODO: can we make this error ^^ into a `Try this:` suggestion that +-- automatically adds the right command just before the theorem? /- used in `sym_n` tactic to specify an initial state -/ syntax sym_at := "at" ident @@ -368,49 +403,30 @@ elab "sym_n" whileTac?:(sym_while)? n:num s:(sym_at)? : tactic => do omega; )) - Lean.Elab.Tactic.withMainContext <| do - let mut c ← SymContext.fromLocalContext s - c ← c.addGoalsForMissingHypotheses - c.canonicalizeHypothesisTypes - - -- Check that we are not asked to simulate more steps than available - let n ← do - let n := n.getNat - match c.runSteps? with - | none => pure n - | some runSteps => - if n ≤ runSteps then - pure n - else - let h_run ← userNameToMessageData c.h_run - logWarning m!"Symbolic simulation using {h_run} is limited to at most {runSteps} steps" - pure runSteps - - -- Check that step theorems have been pre-generated - try - let pc := c.pc.toHexWithoutLeadingZeroes - let step_thm := Name.str c.program ("stepi_eq_0x" ++ pc) - let _ ← getConstInfo step_thm - catch err => - throwErrorAt err.getRef "{err.toMessageData}\n -Did you remember to generate step theorems with: - #genStepEqTheorems {c.program}" --- TODO: can we make this error ^^ into a `Try this:` suggestion that --- automatically adds the right command just before the theorem? + let c ← withMainContext <| SymContext.fromLocalContext s + SymM.run' c <| do + -- Context preparation + addGoalsForMissingHypotheses + canonicalizeHypothesisTypes + + -- Check pre-conditions + assertStepTheoremsGenerated + let n ← ensureAtMostRunSteps n.getNat - -- The main loop - for _ in List.range n do - c ← sym1 c whileTac + withMainContext' <| do + -- The main loop + for _ in List.range n do + sym1 whileTac traceHeartbeats "symbolic simulation total" + let c ← getThe SymContext -- Check if we can substitute the final state if c.runSteps? = some 0 then let msg := do let hRun ← userNameToMessageData c.h_run pure m!"runSteps := 0, substituting along {hRun}" - withTraceNode `Tactic.sym (fun _ => msg) <| withMainContext do - let s ← SymContext.findFromUserName c.state - let sfEq ← mkEq s.toExpr c.finalState + withTraceNode `Tactic.sym (fun _ => msg) <| withMainContext' do + let sfEq ← mkEq (← getCurrentState) c.finalState let goal ← getMainGoal trace[Tactic.sym] "original goal:\n{goal}" @@ -430,7 +446,7 @@ Did you remember to generate step theorems with: -- Rudimentary aggregation: we feed all the axiomatic effect hypotheses -- added while symbolically evaluating to `simp` let msg := m!"aggregating (non-)effects" - withTraceNode `Tactic.sym (fun _ => pure msg) <| withMainContext do + withTraceNode `Tactic.sym (fun _ => pure msg) <| withMainContext' do traceHeartbeats "pre" let goal? ← LNSymSimp (← getMainGoal) c.aggregateSimpCtx c.aggregateSimprocs replaceMainGoal goal?.toList diff --git a/Tactics/SymContext.lean b/Tactics/SymContext.lean index 06e744f8..e672e47f 100644 --- a/Tactics/SymContext.lean +++ b/Tactics/SymContext.lean @@ -35,8 +35,6 @@ open BitVec /-- A `SymContext` collects the names of various variables/hypotheses in the local context required for symbolic evaluation -/ structure SymContext where - /-- `state` is a local variable of type `ArmState` -/ - state : Name /-- `finalState` is an expression of type `ArmState` -/ finalState : Expr /-- `runSteps?` stores the number of steps that we can *maximally* simulate, @@ -54,13 +52,16 @@ structure SymContext where See also `SymContext.runSteps?` -/ h_run : Name - /-- `program` is a *constant* which represents the program being evaluated -/ - program : Name - /-- `h_program` is a local hypothesis of the form `state.program = program` -/ - h_program : Name /-- `programInfo` is the relevant cached `ProgramInfo` -/ programInfo : ProgramInfo + /-- the effects of the current state, such as: + - a proof that the PC is equal to `pc` + - a proof that the current state is valid (`read_err _ = .None`) + - a proof that the current state has the right program + - and more, see `AxEffects` for the full list -/ + effects : AxEffects + /-- `pc` is the *concrete* value of the program counter Note that for now we only support symbolic evaluation of programs @@ -74,8 +75,6 @@ structure SymContext where and we assume that no overflow happens (i.e., `base - x` can never be equal to `base + y`) -/ pc : BitVec 64 - /-- `h_pc` is a local hypothesis of the form `r StateField.PC state = pc` -/ - h_pc : Name /-- `h_err?`, if present, is a local hypothesis of the form `r StateField.ERR state = .None` -/ h_err? : Option Name @@ -98,7 +97,56 @@ structure SymContext where /-- `curr_state_number` is incremented each simulation step, and used together with `curr_state_number` to determine the name of the next state variable that is added by `sym` -/ - curr_state_number : Nat := 0 + currentStateNumber : Nat := 0 + +/-! ## Monad -/ + +/-- `SymM` is a wrapper around `TacticM` with a mutable `SymContext` state -/ +abbrev SymM := StateT SymContext TacticM + +/-- `SymReaderM` is a wrapper around `TacticM` with a read-only `SymContext` state -/ +abbrev SymReaderM := ReaderT SymContext TacticM + +namespace SymM + +def run (ctx : SymContext) (k : SymM α) : TacticM (α × SymContext) := + StateT.run k ctx + +def run' (ctx : SymContext) (k : SymM α) : TacticM α := + StateT.run' k ctx + +instance : MonadLift SymReaderM SymM where + monadLift x c := do return (←x c, c) + +instance : MonadReaderOf AxEffects SymReaderM where + read := do return (← readThe SymContext).effects + +instance : MonadStateOf AxEffects SymM where + get := readThe AxEffects + set effects := do modifyThe SymContext ({· with effects}) + modifyGet f := do + let (a, effects) := f (← getThe SymContext).effects + modifyThe SymContext ({· with effects}) + return a + +/-! +## WORKAROUND for https://github.com/leanprover/lean4/issues/5457 +For some reason, `logWarning` is very slow to elaborate, +so we add a specialized `SymM.logWarning` with a specific instance of `MonadLog` +hidden behind a def. For some reason this is fast to elaborate. +-/ + +/-- This def may seem pointless, but it is in-fact load-bearing. + +Furthermore, making it an `instance` will cause `logWarning` below to be +very slow to elaborate. Why? No clue. -/ +protected def instMonadLog : MonadLog SymM := inferInstance + +@[inherit_doc Lean.logWarning] +def logWarning (msg : MessageData) : SymM Unit := + @Lean.logWarning SymM _ SymM.instMonadLog _ _ msg + +end SymM namespace SymContext @@ -107,23 +155,8 @@ section open Lean (Ident mkIdent) variable (c : SymContext) -/-- `next_state` generates the name for the next intermediate state -/ -def next_state (c : SymContext) : Name := - .mkSimple s!"{c.state_prefix}{c.curr_state_number + 1}" - -/-- return `h_err?` if given, or a default hardcoded name -/ -def h_err : Name := c.h_err?.getD (.mkSimple s!"h_{c.state}_err") - -/-- return `h_sp?` if given, or a default hardcoded name -/ -def h_sp : Name := c.h_err?.getD (.mkSimple s!"h_{c.state}_sp") - -def state_ident : Ident := mkIdent c.state -def next_state_ident : Ident := mkIdent c.next_state -def h_run_ident : Ident := mkIdent c.h_run -def h_program_ident : Ident := mkIdent c.h_program -def h_pc_ident : Ident := mkIdent c.h_pc -def h_err_ident : Ident := mkIdent c.h_err -def h_sp_ident : Ident := mkIdent c.h_sp +/-- `program` is a *constant* which represents the program being evaluated -/ +def program : Name := c.programInfo.name /-- Find the local declaration that corresponds to a given name, or throw an error if no local variable of that name exists -/ @@ -132,57 +165,92 @@ def findFromUserName (name : Name) : MetaM LocalDecl := do | throwError "Unknown local variable `{name}`" return decl -/-- Return an expression for `c.state`, -or throw an error if no local variable of that name exists -/ -def stateExpr : MetaM Expr := - (·.toExpr) <$> findFromUserName c.state - /-- Find the local declaration that corresponds to `c.h_run`, or throw an error if no local variable of that name exists -/ def hRunDecl : MetaM LocalDecl := do findFromUserName c.h_run +section Monad +variable {m} [Monad m] [MonadReaderOf SymContext m] + +def getCurrentStateNumber : m Nat := do return (← read).currentStateNumber + +/-- Return the name of the hypothesis + `h_run : = run ` -/ +def getHRunName : m Name := do return (← read).h_run + +/-- Retrieve the name for the next state + +NOTE: `getNextStateName` does not increment the state, so consecutive calls +will give the same name. Calling `prepareForNextStep` will increment the state. +-/ +def getNextStateName : m Name := do + let c ← read + return Name.mkSimple s!"{c.state_prefix}{c.currentStateNumber + 1}" + +end Monad + end -/-! ## `ToMessageData` instance -/ +/-! ## `ToMessageData` instance and tracing -/ /-- Convert a `SymContext` to `MessageData` for tracing. This is not a `ToMessageData` instance because we need access to `MetaM` -/ def toMessageData (c : SymContext) : MetaM MessageData := do - let state ← c.stateExpr let h_run ← userNameToMessageData c.h_run - let h_err? ← c.h_err?.mapM userNameToMessageData let h_sp? ← c.h_sp?.mapM userNameToMessageData - return m!"\{ state := {state}, - finalState := {c.finalState}, + return m!"\{ finalState := {c.finalState}, runSteps? := {c.runSteps?}, h_run := {h_run}, program := {c.program}, pc := {c.pc}, - h_err? := {h_err?}, h_sp? := {h_sp?}, state_prefix := {c.state_prefix}, - curr_state_number := {c.curr_state_number} }" + curr_state_number := {c.currentStateNumber}, + effects := {c.effects} }" + +variable {α : Type} {m : Type → Type} [Monad m] [MonadTrace m] [MonadLiftT IO m] + [MonadRef m] [AddMessageContext m] [MonadOptions m] {ε : Type} + [MonadAlwaysExcept ε m] [MonadLiftT BaseIO m] in +def withSymTraceNode (msg : MessageData) (k : m α) : m α := do + withTraceNode `Tactic.sym (fun _ => pure msg) k + +def traceSymContext : SymM Unit := + withTraceNode `Tactic.sym (fun _ => pure m!"SymContext: ") <| do + let m ← (← getThe SymContext).toMessageData + trace[Tactic.sym] m /-! ## Creating initial contexts -/ +/-- Modify a `SymContext` with a monadic action `k : SymM Unit` -/ +def modify (ctxt : SymContext) (k : SymM Unit) : TacticM SymContext := do + let ((), ctxt) ← SymM.run ctxt k + return ctxt + /-- Infer `state_prefix` and `curr_state_number` from the `state` name as follows: if `state` is `s{i}` for some number `i` and a single character `s`, then `s` is the prefix and `i` the number, -otherwise ignore `state`, and start counting from `s1` -/ -def inferStatePrefixAndNumber (ctxt : SymContext) : SymContext := - let state := ctxt.state.toString +otherwise ignore `state`, log a warning, and start counting from `s1` -/ +def inferStatePrefixAndNumber : SymM Unit := do + let state ← AxEffects.getCurrentStateName + let state := state.toString let tail := state.toSubstring.drop 1 - if let some curr_state_number := tail.toNat? then - { ctxt with + if let some currentStateNumber := tail.toNat? then + modifyThe SymContext ({ · with state_prefix := (state.get? 0).getD 's' |>.toString, - curr_state_number } + currentStateNumber }) else - { ctxt with + SymM.logWarning "\ + Expected state to be a single letter followed by a number, but found: + {state} + + Falling back to the default numbering schema, + with `s1` as the first new intermediate state" + modifyThe SymContext ({ · with state_prefix := "s", - curr_state_number := 1 } + currentStateNumber := 1 }) /-- Annotate any errors thrown by `k` with a local variable (and its type) -/ private def withErrorContext (name : Name) (type? : Option Expr) (k : MetaM α) : @@ -195,8 +263,13 @@ private def withErrorContext (name : Name) (type? : Option Expr) (k : MetaM α) throwErrorAt e.getRef "{e.toMessageData}\n\nIn {h}{type}" /-- Build a `SymContext` by searching the local context for hypotheses of the -required types (up-to defeq) -/ -def fromLocalContext (state? : Option Name) : MetaM SymContext := do +required types (up-to defeq) . The local context is modified to unfold the types +to be syntactically equal to the expected type. + +If an hypothesis `h_err : r .ERR = None` is not found, +we create a new subgoal of this type +-/ +def fromLocalContext (state? : Option Name) : TacticM SymContext := do let msg := m!"Building a `SymContext` from the local context" withTraceNode `Tactic.sym (fun _ => pure msg) do trace[Tactic.Sym] "state? := {state?}" @@ -229,37 +302,19 @@ def fromLocalContext (state? : Option Name) : MetaM SymContext := do -- At this point, `stateExpr` should have been assigned (if it was an mvar), -- so we can unwrap it to get the underlying name let stateExpr ← instantiateMVars stateExpr - let state ← state?.getDM <| do - let .fvar state := stateExpr - | let h_run_type ← instantiateMVars h_run.type - let h_run := h_run.toExpr - throwError - "Expected a free variable, found: - {stateExpr} - We inferred this as the initial state because we found: - {h_run} : {h_run_type} - in the local context. - - If this is wrong, please explicitly provide the right initial state, - as in `sym {runSteps} at ?s0` - " - let some state := lctx.find? state - /- I don't expect this error to be possible in a well-formed state, - but you never know -/ - | throwError "Failed to find local variable for state {stateExpr}" - pure state.userName -- Try to find `h_program`, and infer `program` from it let ⟨h_program, program⟩ ← withErrorContext h_run.userName h_run.type <| findProgramHyp stateExpr -- Then, try to find `h_pc` - let pc ← mkFreshExprMVar (← mkAppM ``BitVec #[toExpr 64]) - let h_pc ← findLocalDeclOfTypeOrError <| h_pc_type stateExpr pc + let pcE ← mkFreshExprMVar (← mkAppM ``BitVec #[toExpr 64]) + let h_pc ← findLocalDeclOfTypeOrError <| h_pc_type stateExpr pcE -- Unwrap and reflect `pc` - let pc ← instantiateMVars pc - let pc ← withErrorContext h_pc.userName h_pc.type <| reflectBitVecLiteral 64 pc + let pcE ← instantiateMVars pcE + let pc ← withErrorContext h_pc.userName h_pc.type <| + reflectBitVecLiteral 64 pcE -- Attempt to find `h_err` and `h_sp` let h_err? ← findLocalDeclOfType? (h_err_type stateExpr) @@ -283,16 +338,31 @@ def fromLocalContext (state? : Option Name) : MetaM SymContext := do (decls := axHyps) (noIndexAtArgs := false) - return inferStatePrefixAndNumber { - state, finalState, runSteps?, program, pc, + -- Build an initial AxEffects + let effects := AxEffects.initial stateExpr + let mut fields := + effects.fields.insert .PC { value := pcE, proof := h_pc.toExpr} + if let some hErr := h_err? then + fields := fields.insert .ERR { + value := mkConst ``StateError.None, + proof := hErr.toExpr + } + let effects := { effects with + programProof := h_program.toExpr + stackAlignmentProof? := h_sp?.map (·.toExpr) + fields + } + let c : SymContext := { + finalState, runSteps?, pc, h_run := h_run.userName, - h_program := h_program.userName, - h_pc := h_pc.userName h_err? := (·.userName) <$> h_err?, h_sp? := (·.userName) <$> h_sp?, programInfo, + effects, aggregateSimpCtx, aggregateSimprocs } + c.modify <| + inferStatePrefixAndNumber where findLocalDeclOfType? (expectedType : Expr) : MetaM (Option LocalDecl) := do let msg := m!"Searching for hypothesis of type: {expectedType}" @@ -307,30 +377,25 @@ where trace[Tactic.sym] "Found: {decl.toExpr}" return decl - - /-! ## Massaging the local context -/ /-- If `h_sp` or `h_err` are missing from the `SymContext`, add new goals of the expected types, and use these to add `h_sp` and `h_err` to the main goal context -/ -def addGoalsForMissingHypotheses (ctx : SymContext) (addHSp : Bool := false) : - TacticM SymContext := +def addGoalsForMissingHypotheses (addHSp : Bool := false) : SymM Unit := let msg := "Adding goals for missing hypotheses" - withTraceNode `Tactic.sym (fun _ => pure msg) <| withMainContext do - let mut ctx := ctx + withTraceNode `Tactic.sym (fun _ => pure msg) <| withMainContext' do + let mut ctx ← getThe SymContext let mut goal ← getMainGoal let mut newGoals := [] - let lCtx ← getLCtx - let some stateExpr := - (Expr.fvar ·.fvarId) <$> lCtx.findFromUserName? ctx.state - | throwError "Could not find '{ctx.state}' in the local context" + let stateExpr ← AxEffects.getCurrentState + let stateName ← AxEffects.getCurrentStateName match ctx.h_err? with | none => trace[Tactic.sym] "h_err? is none, adding a new goal" - let h_err? := Name.mkSimple s!"h_{ctx.state}_run" + let h_err? := Name.mkSimple s!"h_{stateName}_err" let newGoal ← mkFreshMVarId goal := ← do @@ -341,7 +406,10 @@ def addGoalsForMissingHypotheses (ctx : SymContext) (addHSp : Bool := false) : return goal' newGoals := newGoal :: newGoals - ctx := { ctx with h_err? } + ctx := { ctx with + h_err? + effects := ← ctx.effects.withField (.mvar newGoal) + } | some h_err => let h_err ← userNameToMessageData h_err trace[Tactic.sym] "h_err? is {h_err}, no new goal needed" @@ -351,7 +419,7 @@ def addGoalsForMissingHypotheses (ctx : SymContext) (addHSp : Bool := false) : if addHSp then trace[Tactic.sym] "h_sp? is none, adding a new goal" - let h_sp? := Name.mkSimple s!"h_{ctx.state}_sp" + let h_sp? := Name.mkSimple s!"h_{stateName}_sp" let newGoal ← mkFreshMVarId goal := ← do @@ -362,7 +430,10 @@ def addGoalsForMissingHypotheses (ctx : SymContext) (addHSp : Bool := false) : return goal' newGoals := newGoal :: newGoals - ctx := { ctx with h_sp? } + ctx := { ctx with + h_sp? + effects.stackAlignmentProof? := some (Expr.mvar newGoal) + } else trace[Tactic.sym] "h_sp? is none, but addHSp is false, \ so no new goal is added" @@ -371,7 +442,7 @@ def addGoalsForMissingHypotheses (ctx : SymContext) (addHSp : Bool := false) : trace[Tactic.sym] "h_sp? is {h_sp}, no new goal needed" replaceMainGoal (goal :: newGoals) - return ctx + set ctx /-- change the type (in the local context of the main goal) of the hypotheses tracked by the given `SymContext` to be *exactly* of the shape @@ -379,16 +450,12 @@ described in the relevant docstrings. That is, (un)fold types which were definitionally, but not syntactically, equal to the expected shape. -/ -def canonicalizeHypothesisTypes (c : SymContext) : TacticM Unit := withMainContext do +def canonicalizeHypothesisTypes : SymReaderM Unit := fun c => withMainContext do let lctx ← getLCtx let mut goal ← getMainGoal - let state ← c.stateExpr - let program := mkConst c.program + let state := c.effects.currentState - let mut hyps := #[ - (c.h_program, h_program_type state program), - (c.h_pc, h_pc_type state (toExpr c.pc)) - ] + let mut hyps := #[] if let some runSteps := c.runSteps? then hyps := hyps.push (c.h_run, h_run_type c.finalState (toExpr runSteps) state) if let some h_err := c.h_err? then @@ -398,30 +465,34 @@ def canonicalizeHypothesisTypes (c : SymContext) : TacticM Unit := withMainConte for ⟨name, type⟩ in hyps do let some decl := lctx.findFromUserName? name - | throwError "Unknown local hypothesis `{c.state}`" + | throwError "Unknown local hypothesis `{name}`" goal ← goal.replaceLocalDeclDefEq decl.fvarId type replaceMainGoal [goal] /-! ## Incrementing the context to the next state -/ -/-- `c.next` generates names for the next intermediate state and its hypotheses - -`nextPc?`, if given, will be the pc of the next context. -If `nextPC?` is `none`, then the previous pc is incremented by 4 -/ -def next (c : SymContext) (nextPc? : Option (BitVec 64) := none) : - SymContext := - let curr_state_number := c.curr_state_number + 1 - let s := c.next_state - { c with - state := s - h_program := .mkSimple s!"h_{s}_program" - h_pc := .mkSimple s!"h_{s}_pc" - h_err? := c.h_err?.map (fun _ => .mkSimple s!"h_{s}_err") +/-- `prepareForNextStep` prepares the state for the next step of symbolic +evaluation: + * `pc` is reflected from the stored `effects` + * `runSteps?`, if specified, is decremented, + * the `currentStateNumber` is incremented +-/ +def prepareForNextStep : SymM Unit := do + let s ← getNextStateName + let pc ← do + let { value, ..} ← AxEffects.getFieldM .PC + try + reflectBitVecLiteral 64 value + catch err => + trace[Tactic.sym] "failed to reflect PC: {err.toMessageData}" + pure <| (← getThe SymContext).pc + 4 + + modifyThe SymContext (fun c => { c with + pc h_sp? := c.h_sp?.map (fun _ => .mkSimple s!"h_{s}_sp_aligned") runSteps? := (· - 1) <$> c.runSteps? - pc := nextPc?.getD (c.pc + 4#64) - curr_state_number - } + currentStateNumber := c.currentStateNumber + 1 + }) /-- Add a set of new simp-theorems to the simp-theorems used for effect aggregation -/ diff --git a/lakefile.lean b/lakefile.lean index cc494af8..80f20126 100644 --- a/lakefile.lean +++ b/lakefile.lean @@ -1,3 +1,8 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Shilpi Goel +-/ import Lake open Lake DSL