Skip to content

Commit

Permalink
Merge branch 'main' into refactor-state-monads-6
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkeizer authored Sep 30, 2024
2 parents 46054db + 16e3f4f commit 02e3456
Show file tree
Hide file tree
Showing 12 changed files with 452 additions and 54 deletions.
207 changes: 207 additions & 0 deletions Arm/Memory/AddressNormalization.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Siddharth Bhat, Tobias Grosser
-/

/-
This file implements bitvector expression simplification simprocs.
We perform the following additional changes:
1. Canonicalizing bitvector expression to always have constants on the left.
Recall that the default associativity of addition is to the left: x + y + z = (x + y) + z.
If we thus normalize our expressions to have constants on the left,
and if we constant-fold constants, we will naturally perform canonicalization.
That is, the two rewrites:
(a) (x + c) -> (c + x).
(b) x + (y + z) -> (x + y) + z.
combine to achieve constant folding. Observe an example:
((x + 10) + 20)
-a-> (20 + (x + 10))
-a-> (20 + (10 + x))
-b-> (20 + 10) + x
-reduceAdd-> 30 + x
2. Canonicalizing (a + b) % n → a % n + b % n by exploiting `bv_omega`,
and eventually, `simp_mem`.
-/
import Lean
import Arm.Memory.Attr
import Arm.Attr
import Tactics.Common

open Lean Meta Elab Simp


/-- If `y ≤ x` and `x - y < y`, then `x % y` is exactly `x - y`:
one subtraction of the modulus is enough to land in range. -/
theorem Nat.mod_eq_sub {x y : Nat} (h : x ≥ y) (h' : x - y < y) :
    x % y = x - y :=
  -- `x % y = (x - y) % y` (since `y ≤ x`), and `(x - y) % y = x - y` (since `x - y < y`).
  (Nat.mod_eq_sub_mod h).trans (Nat.mod_eq_of_lt h')

/-- Build the expression `x < y` where `<` is the `Nat` instance of `LT.lt`. -/
private def mkLTNat (x y : Expr) : Expr :=
  mkApp4 (.const ``LT.lt [levelZero]) (mkConst ``Nat) (mkConst ``instLTNat) x y

/-- Build the expression `x ≤ y` where `≤` is the `Nat` instance of `LE.le`. -/
private def mkLENat (x y : Expr) : Expr :=
  mkApp4 (.const ``LE.le [levelZero]) (mkConst ``Nat) (mkConst ``instLENat) x y

/-- Build the expression `x ≥ y` on `Nat`, represented as `y ≤ x` via `mkLENat`. -/
private def mkGENat (x y : Expr) : Expr :=
  mkLENat y x

/-- Build the expression `x - y` using `HSub.hSub` at type `Nat`,
with the heterogeneous instance obtained from `instSubNat` through `instHSub`. -/
private def mkSubNat (x y : Expr) : Expr :=
  let natTy := mkConst ``Nat
  let hsubInst := mkApp2 (mkConst ``instHSub [levelZero]) natTy (mkConst ``instSubNat)
  mkApp6 (mkConst ``HSub.hSub [levelZero, levelZero, levelZero])
    natTy natTy natTy hsubInst x y

/--
Recognize an application `BitVec.ofNat a b` and return its two arguments, in application
order, as the pair `(a, b)`; return `none` for any other expression.
Callers in this file treat the first component as the width and the second as the value.
Both components are returned as raw `Expr`s and neither is required to be a literal.
Notice that this is different from `getBitVecValue?` in that here we allow the width
to be symbolic, which is why no concrete `BitVec` is produced.
-/
def getBitVecOfNatValue? (e : Expr) : Option (Expr × Expr) :=
  match_expr e with
  | BitVec.ofNat width value => some (width, value)
  | _ => none

/--
Try to build a proof of the proposition `ty` by handing it to `omega`.

A fresh metavariable of type `ty` is created, its goal is turned into a refutation goal
with `falseOrByContra`, and `omega` is run on that goal with all local hypotheses as facts.

Returns `some proof` (the metavariable standing for the proof) when `omega` succeeds,
and `none` when `falseOrByContra` yields no goal or when `omega` throws.

This is meant to automatically prove in-bounds side conditions so that a simproc can
eliminate a modulo, hence the use of `SimpM`.
We may eventually want to exploit our memory automation framework to bring in
more `omega` facts.
-/
@[inline] def dischargeByOmega (ty : Expr) : SimpM (Option Expr) := do
  let goalMVar : Expr ← mkFreshExprMVar ty
  match (← goalMVar.mvarId!.falseOrByContra) with
  | none => return none
  | some contraGoal =>
    try
      contraGoal.withContext do
        let hyps ← getLocalHyps
        Lean.Elab.Tactic.Omega.omega hyps.toList contraGoal {}
    catch _ =>
      -- `omega` could not close the goal: report failure instead of propagating the error.
      return none
    return some goalMVar

/--
Rewrite `x % n` to `x` when `omega` can prove `x < n`.
On success the step is `.done` with a proof built from `Nat.mod_eq_of_lt`;
otherwise the step is `.continue`.
-/
@[inline] def reduceModOfLt (x : Expr) (n : Expr) : SimpM Step := do
  trace[Tactic.address_normalization] "{processingEmoji} reduceModOfLt '{x} % {n}'"
  match (← dischargeByOmega (mkLTNat x n)) with
  | none => return .continue
  | some ltProof =>
    let modEq ← mkAppM ``Nat.mod_eq_of_lt #[ltProof]
    trace[Tactic.address_normalization] "{checkEmoji} reduceModOfLt '{x} % {n}'"
    return .done { expr := x, proof? := modEq : Result }

/--
Rewrite `x % n` to `x - n` when `omega` can prove both `x ≥ n` and `x - n < n`.
On success the step is `.done` with a proof built from `Nat.mod_eq_sub`;
if either side condition cannot be discharged the step is `.continue`.
-/
@[inline] def reduceModSub (x : Expr) (n : Expr) : SimpM Step := do
  trace[Tactic.address_normalization] "{processingEmoji} reduceModSub '{x} % {n}'"
  match (← dischargeByOmega (mkGENat x n)) with
  | none => return .continue
  | some geProof =>
    -- `diff` is both the rewritten expression and the left side of the second obligation.
    let diff := mkSubNat x n
    match (← dischargeByOmega (mkLTNat diff n)) with
    | none => return .continue
    | some ltProof =>
      let modEq ← mkAppM ``Nat.mod_eq_sub #[geProof, ltProof]
      trace[Tactic.address_normalization] "{checkEmoji} reduceModSub '{x} % {n}'"
      return .done { expr := diff, proof? := modEq : Result }

/--
Simplify an `HMod.hMod` application whose three type arguments are all syntactically `Nat`.
First try `reduceModOfLt` (`x % n = x`), then `reduceModSub` (`x % n = x - n`).
Any other expression, or a failure of both attempts, yields `.continue`.
-/
@[inline, bv_toNat] def reduceMod (e : Expr) : SimpM Step := do
  match_expr e with
  | HMod.hMod xTy nTy outTy _inst x n =>
    let natTy := mkConst ``Nat
    -- Only handle `%` whose operands and result are all `Nat`.
    unless xTy == natTy && nTy == natTy && outTy == natTy do
      return .continue
    match (← reduceModOfLt x n) with
    | .done res => return .done res
    | _ =>
      match (← reduceModSub x n) with
      | .done res => return .done res
      | _ => return .continue
  | _ => return .continue

-- Pre-simproc (`↑`) in the `address_normalization` set: run `reduceMod` on every `_ % _` term.
simproc↑ [address_normalization] reduce_mod_omega (_ % _) := reduceMod

/-- Canonicalize a commutative binary operation `f x y` on bitvectors.

The expression `fxy` must be an application of `declName` with exactly `arity` arguments;
its last two arguments are taken as the operands `x` and `y`. A "constant" here is any
operand of the form `BitVec.ofNat w n`, as recognized by `getBitVecOfNatValue?`.

1. If both arguments are constants, we perform constant folding: `reduceProofDecl` is applied
   to the width and the two values, and the rewritten expression is the right-hand side of
   the resulting equation.
2. If only the right argument is a constant, we move the constant to the left, justified by
   `commProofDecl x y`.
3. Otherwise (no constant at all, or a constant only on the left) we return `.continue`.
-/
@[inline, bv_toNat] def canonicalizeBinConst (declName : Name) -- operator to constant fold, such as `HAdd.hAdd`.
    (arity : Nat)
    (commProofDecl : Name) -- commProof: `∀ (x y : BitVec w), x op y = y op x`.
    (reduceProofDecl : Name) -- reduce proof: `∀ (w : Nat) (n m : Nat), (BitVec.ofNat w n) op (BitVec.ofNat w m) = BitVec.ofNat w (n op' m)`.
    (fxy : Expr) : SimpM Step := do
  unless fxy.isAppOfArity declName arity do return .continue
  let fx := fxy.appFn!
  let x := fx.appArg!
  let f := fx.appFn!
  let y := fxy.appArg!
  trace[Tactic.address_normalization] "{processingEmoji} canonicalizeBinConst '({f} {x} {y})'"
  match getBitVecOfNatValue? x with
  | some (xwExpr, xvalExpr) =>
    -- We have a constant on the left, check if we have a constant on the right
    -- so we can fully constant fold the expression.
    let .some (_, yvalExpr) := getBitVecOfNatValue? y
      | return .continue

    -- `proof : BitVec.ofNat w n op BitVec.ofNat w m = BitVec.ofNat w (n op' m)`.
    -- The lemma takes the width and the two values, not the bitvector operands themselves.
    let proof ← mkAppM reduceProofDecl #[xwExpr, xvalExpr, yvalExpr]
    -- The simplified expression is the right-hand side of the equation that `proof` proves
    -- (and not the proof term itself).
    let some (_, _, e') := (← inferType proof).eq?
      | return .continue
    trace[Tactic.address_normalization] "{checkEmoji} canonicalizeBinConst '({f} {x} {y})'"
    return .done { expr := e', proof? := proof : Result }

  | none =>
    -- We don't have a constant on the left, check if we have a constant on the right
    -- and try to move it to the left.
    let .some _ := getBitVecOfNatValue? y
      | return .continue -- no constants on either side, nothing to do.

    -- Nothing more to do, except to move the right constant to the left.
    let e' := mkAppN f #[y, x]
    trace[Tactic.address_normalization] "{checkEmoji} canonicalizeBinConst '({f} {x} {y})'"
    return .done { expr := e', proof? := ← mkAppM commProofDecl #[x, y] : Result }

/- Change `100` to `100#64` so we can pattern match to `BitVec.ofNat`.
The lemma is registered under the `address_normalization` attribute. -/
attribute [address_normalization] BitVec.ofNat_eq_ofNat


/-- Move a `BitVec.ofNat` constant from the right of an addition to the left. -/
theorem BitVec.add_ofNat_eq_ofNat_add {w n} (x : BitVec w) :
    x + BitVec.ofNat w n = BitVec.ofNat w n + x :=
  BitVec.add_comm x (BitVec.ofNat w n)


/-- Move a `BitVec.ofNat` constant from the right of a multiplication to the left. -/
theorem BitVec.mul_ofNat_eq_ofNat_mul {w n} (x : BitVec w) :
    x * BitVec.ofNat w n = BitVec.ofNat w n * x :=
  BitVec.mul_comm x (BitVec.ofNat w n)

/-- Constant folding for addition of two `BitVec.ofNat` literals of the same width.
Declared before `constFoldAdd` so that the simproc can refer to it by name. -/
@[address_normalization]
theorem BitVec.ofNat_add_ofNat_eq_add_ofNat (w : Nat) (n m : Nat) :
    BitVec.ofNat w n + BitVec.ofNat w m = BitVec.ofNat w (n + m) := by
  apply BitVec.eq_of_toNat_eq
  simp

/-- Constant folding for multiplication of two `BitVec.ofNat` literals of the same width.
Declared before `constFoldMul` so that the simproc can refer to it by name. -/
@[address_normalization]
theorem BitVec.ofNat_mul_ofNat_eq_mul_ofNat (w : Nat) (n m : Nat) :
    BitVec.ofNat w n * BitVec.ofNat w m = BitVec.ofNat w (n * m) := by
  apply BitVec.eq_of_toNat_eq
  -- Note that omega cannot close the goal since it's symbolic multiplication.
  simp only [toNat_mul, toNat_ofNat, ← Nat.mul_mod]

-- Canonicalize bitvector additions: fold `ofNat + ofNat`, and move a right constant to the left.
-- The last argument must be the fold lemma `∀ w n m, ofNat w n + ofNat w m = ofNat w (n + m)`,
-- as required by the `reduceProofDecl` parameter of `canonicalizeBinConst`.
simproc [address_normalization] constFoldAdd ((_ + _ : BitVec _)) :=
  canonicalizeBinConst ``HAdd.hAdd 6 ``BitVec.add_comm ``BitVec.ofNat_add_ofNat_eq_add_ofNat

-- Canonicalize bitvector multiplications, analogously to `constFoldAdd`.
simproc [address_normalization] constFoldMul ((_ * _ : BitVec _)) :=
  canonicalizeBinConst ``HMul.hMul 6 ``BitVec.mul_comm ``BitVec.ofNat_mul_ofNat_eq_mul_ofNat

/-- Reassociate addition to the left: `x + (y + z)` becomes `(x + y) + z`. -/
@[address_normalization]
theorem BitVec.add_assoc_symm {w} (x y z : BitVec w) : x + (y + z) = x + y + z :=
  (BitVec.add_assoc x y z).symm

/-- Reassociate multiplication to the left: `x * (y * z)` becomes `(x * y) * z`. -/
@[address_normalization]
theorem BitVec.mul_assoc_symm {w} (x y z : BitVec w) : x * (y * z) = x * y * z :=
  (BitVec.mul_assoc x y z).symm
6 changes: 6 additions & 0 deletions Arm/Memory/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ initialize Lean.registerTraceClass `simp_mem
/-- Provides extremely verbose tracing for the `simp_mem` tactic. -/
initialize Lean.registerTraceClass `simp_mem.info

/-- Provides tracing for the address normalization simprocs
(emitted via `trace[Tactic.address_normalization]`). -/
initialize Lean.registerTraceClass `Tactic.address_normalization

-- Rules for simprocs that mine the state to extract information for `omega`
-- to run.
register_simp_attr memory_omega

-- Simprocs for address normalization
register_simp_attr address_normalization
1 change: 1 addition & 0 deletions Arm/Memory/SeparateAutomation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Arm
import Arm.Memory.MemoryProofs
import Arm.BitVec
import Arm.Memory.Attr
import Arm.Memory.AddressNormalization
import Lean
import Lean.Meta.Tactic.Rewrite
import Lean.Meta.Tactic.Rewrites
Expand Down
2 changes: 2 additions & 0 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ These mnemonics make it much easier to read and write theorems about assembly pr
-/

@[state_simp_rules] abbrev ArmState.x0 (s : ArmState) : BitVec 64 := r (StateField.GPR 0) s
@[state_simp_rules] abbrev ArmState.w0 (s : ArmState) : BitVec 32 :=
(r (StateField.GPR 0) s).zeroExtend 32

@[state_simp_rules] abbrev ArmState.x1 (s : ArmState) : BitVec 64 := r (StateField.GPR 1) s

Expand Down
25 changes: 25 additions & 0 deletions Arm/Syntax.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,36 @@ Author(s): Siddharth Bhat
Provide convenient syntax for writing down state manipulation in Arm programs.
-/
import Arm.State
import Arm.Memory.Separate

namespace ArmStateNotation

/-! We build a notation for `read_mem_bytes $n $base $s` as `$s[$base, $n]` -/
@[inherit_doc read_mem_bytes]
syntax:max term noWs "[" withoutPosition(term) "," withoutPosition(term) noWs "]" : term
macro_rules | `($s[$base,$n]) => `(read_mem_bytes $n $base $s)


/-! Notation to specify the frame condition for non-memory state components. E.g.,
`REGS_UNCHANGED_EXCEPT [.GPR 0, .PC] (sf, s0)` is sugar for
`∀ f, f ∉ [.GPR 0, .PC] → r f sf = r f s0`
-/
syntax:max "REGS_UNCHANGED_EXCEPT" "[" term,* "]"
"(" withoutPosition(term) "," withoutPosition(term) ")" : term
macro_rules
| `(REGS_UNCHANGED_EXCEPT [$regs:term,*] ($sf, $s0)) =>
`(∀ f, f ∉ [$regs,*] → r f $sf = r f $s0)

/-! Notation to specify the frame condition for memory regions. E.g.,
`MEM_UNCHANGED_EXCEPT [(x, m), (y, k)] (sf, s0)` is sugar for
`∀ n addr, Memory.Region.pairwiseSeparate [(addr, n), (x, m), (y, k)] → sf[addr, n] = s0[addr, n]`
-/
syntax:max "MEM_UNCHANGED_EXCEPT" "[" term,* "]"
"(" withoutPosition(term) "," withoutPosition(term) ")" : term
macro_rules |
`(MEM_UNCHANGED_EXCEPT [$mem:term,*] ($sf, $s0)) =>
`(∀ (n : Nat) (addr : BitVec 64),
Memory.Region.pairwiseSeparate (List.cons (addr, n) [$mem,*]) →
read_mem_bytes n addr $sf = read_mem_bytes n addr $s0)

end ArmStateNotation
74 changes: 73 additions & 1 deletion Proofs/AES-GCM/GCMGmultV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,93 @@ Author(s): Alex Keizer
-/
import Tests.«AES-GCM».GCMGmultV8Program
import Tactics.Sym
import Tactics.Aggregate
import Tactics.StepThms
import Tactics.CSE
import Arm.Memory.SeparateAutomation
import Arm.Syntax

namespace GCMGmultV8Program
open ArmStateNotation

#genStepEqTheorems gcm_gmult_v8_program

/-
xxx: GCMGmultV8 Xi HTable
-/

set_option pp.deepTerms false in
set_option pp.deepTerms.threshold 50 in
-- set_option trace.simp_mem.info true in
theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
(h_s0_program : s0.program = gcm_gmult_v8_program)
(h_s0_err : read_err s0 = .None)
(h_s0_pc : read_pc s0 = gcm_gmult_v8_program.min)
(h_s0_sp_aligned : CheckSPAlignment s0)
(h_Xi : Xi = s0[read_gpr 64 0#5 s0, 16])
(h_HTable : HTable = s0[read_gpr 64 1#5 s0, 256])
(h_mem_sep : Memory.Region.pairwiseSeparate
[(read_gpr 64 0#5 s0, 16),
(read_gpr 64 1#5 s0, 256)])
(h_run : sf = run gcm_gmult_v8_program.length s0) :
read_err sf = .None := by
-- The final state is error-free.
read_err sf = .None ∧
-- The program is unmodified in `sf`.
sf.program = gcm_gmult_v8_program ∧
-- The stack pointer is still aligned in `sf`.
CheckSPAlignment sf ∧
-- The final state returns to the address in register `x30` in `s0`.
read_pc sf = r (StateField.GPR 30#5) s0 ∧
-- HTable is unmodified.
sf[read_gpr 64 1#5 s0, 256] = HTable ∧
-- Frame conditions.
-- Note that the following also covers that the Xi address in .GPR 0
-- is unmodified.
REGS_UNCHANGED_EXCEPT [.SFP 0, .SFP 1, .SFP 2, .SFP 3,
.SFP 17, .SFP 18, .SFP 19, .SFP 20,
.SFP 21, .PC]
(sf, s0) ∧
-- Memory frame condition.
MEM_UNCHANGED_EXCEPT [(r (.GPR 0) s0, 128)] (sf, s0) := by
simp_all only [state_simp_rules, -h_run] -- prelude
simp (config := {ground := true}) only at h_s0_pc
-- ^^ Still needed, because `gcm_gmult_v8_program.min` is somehow
-- unable to be reflected
sym_n 27
-- Epilogue
simp only [←Memory.mem_eq_iff_read_mem_bytes_eq] at *
simp only [memory_rules] at *
sym_aggregate
-- Split conjunction
repeat' apply And.intro
· -- Aggregate the memory (non)effects.
-- (FIXME) This will be tackled by `sym_aggregate` when `sym_n` and `simp_mem`
-- are merged.
simp only [*]
/-
(FIXME @bollu) `simp_mem; rfl` creates a malformed proof here. The tactic produces
no goals, but we get the following error message:
application type mismatch
Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
(Eq.mp (congrArg (Eq HTable) (Memory.State.read_mem_bytes_eq_mem_read_bytes s0))
(Eq.mp (congrArg (fun x => HTable = read_mem_bytes 256 x s0) zeroExtend_eq_of_r_gpr) h_HTable))
argument has type
HTable = Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem
but function has type
Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem = HTable →
mem_subset' (r (StateField.GPR 1#5) s0) 256 (r (StateField.GPR 1#5) s0) 256 →
Memory.read_bytes 256 (r (StateField.GPR 1#5) s0) s0.mem =
HTable.extractLsBytes (BitVec.toNat (r (StateField.GPR 1#5) s0) - BitVec.toNat (r (StateField.GPR 1#5) s0)) 256
simp_mem; rfl
-/
rw [Memory.read_bytes_write_bytes_eq_read_bytes_of_mem_separate']
simp_mem
· simp only [List.mem_cons, List.mem_singleton, not_or, and_imp]
sym_aggregate
· intro n addr h_separate
simp_mem (config := { useOmegaToClose := false })
-- Aggregate the memory (non)effects.
simp only [*]
done
Loading

0 comments on commit 02e3456

Please sign in to comment.