Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add FIPS-aligned AES specification #29

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ protected theorem zero_le_sub (x y : BitVec n) :

@[simp]
protected theorem zero_or (x : BitVec n) : 0#n ||| x = x := by
unfold HOr.hOr instHOr OrOp.or instOrOpBitVec BitVec.or
unfold HOr.hOr instHOrOfOrOp BitVec.instOrOp BitVec.or
simp only [toNat_ofNat, Nat.or_zero]
congr

Expand Down Expand Up @@ -391,6 +391,26 @@ protected theorem truncate_to_lsb_of_append (m n : Nat) (x : BitVec m) (y : BitV
truncate n (x ++ y) = y := by
simp only [truncate_append, Nat.le_refl, ↓reduceDite, zeroExtend_eq]

@[simp] theorem extractLsb'_cast {k l: Nat} (e: k=l) (lo : Nat) (x : BitVec n):
BitVec.cast e (extractLsb' lo k x) = extractLsb' lo l x := by cases e; simp

@[simp] theorem extractLsb'_cast' {k: Nat} (e: n=m) (lo : Nat) (x : BitVec n):
(extractLsb' lo k (BitVec.cast e x)) = extractLsb' lo k x := by cases e; simp

@[simp] theorem getLsb'_extract (lo k : Nat) (x : BitVec n) (i : Nat) :
getLsb (extractLsb' lo k x) i = (i < k && getLsb x (lo+i)) := by
unfold getLsb
simp [Nat.lt_succ]

@[simp] theorem extractLsb'_of_append {x : BitVec w} {y : BitVec v} :
(x ++ y).extractLsb' v w = x := by
apply eq_of_getLsb_eq
intro i
have k: ¬(v + i) < v := by omega
have k': v + i - v = i := by omega
simp [getLsb'_extract, k, k']
done

----------------------------------------------------------------------

/- Bitvector pattern component syntax category, originally written by
Expand Down
13 changes: 2 additions & 11 deletions Arm/Map.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,8 @@ def Map.size (m : Map α β) : Nat :=
@[simp] theorem Map.size_erase_le [DecidableEq α] (m : Map α β) (a : α) : (m.erase a).size ≤ m.size := by
induction m <;> simp [erase, size] at *
split
next =>
-- (FIXME) This could be discharged by omega in
-- leanprover/lean4:nightly-2024-02-24, but not in
-- leanprover/lean4:nightly-2024-03-01.
exact Nat.le_succ_of_le (by assumption)
next =>
simp;
-- (FIXME) This could be discharged by omega in
-- leanprover/lean4:nightly-2024-02-24, but not in
-- leanprover/lean4:nightly-2024-03-01.
exact Nat.succ_le_succ (by assumption)
next => omega
next => simp; omega

@[simp] theorem Map.size_erase_eq [DecidableEq α] (m : Map α β) (a : α) : m.contains a = false → (m.erase a).size = m.size := by
induction m <;> simp [erase, size] at *
Expand Down
3 changes: 1 addition & 2 deletions Arm/SeparateProofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ theorem n_minus_1_lt_2_64_1 (n : Nat)
refine BitVec.val_bitvec_lt.mp ?a
simp [BitVec.bitvec_to_nat_of_nat]
have : n - 1 < 2 ^ 64 := by omega
simp_all [Nat.mod_eq_of_lt]
exact Nat.sub_lt_left_of_lt_add h1 h2
omega

-- (FIXME) Prove for all bitvector widths.
theorem BitVec.add_sub_self_left_64 (a m : BitVec 64) :
Expand Down
60 changes: 60 additions & 0 deletions Arm/Vec.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import Lean
import Tactics.Elim
import Init.Data.List.Lemmas

open Lean Meta Elab.Tactic

/- Ad-hoc definitions and lemmas about fixed-length lists

MathLib has `Vec` which could be used instead; for now, we use our own
version to keep dependencies down.

-/

abbrev Vec α (n : Nat) := { v : List α // v.length = n }

@[simp]
theorem Vec_eq_transport {k l : Nat} (h : k = l) (x : Vec α k) : (h▸x).val = x.val := by cases h; simp

@[simp] def Vec.inBounds (_ : Vec α n) (i : Nat) : Prop := i < n

@[simp] def Vec.get {n : Nat} (v : Vec α n) (i : Nat) (ok : v.inBounds i) : α :=
let hi: i < v.val.length := by simpa [v.property] using ok
v.val[i]'hi

def Vec.empty : Vec α 0 := ⟨[], by simp⟩

def Vec.cons {n : Nat} (x : α) (v : Vec α n) : Vec α (n+1) :=
⟨List.cons x v.val, by simp [v.property]⟩

theorem Vec.ext'' (x y : Vec α n) (h: x.val = y.val) : x = y := by
cases x <;> cases y; simp_all

@[simp] def Vec.append {n m : Nat} (v : Vec α n) (w : Vec α m) : Vec α (n + m) :=
⟨v.val ++ w.val, by simp [v.property, w.property]⟩

instance : HAppend (Vec α n) (Vec α m) (Vec α (n+m)) where
hAppend xs ys := Vec.append xs ys

@[simp] def Vec.push {n : Nat} (v : Vec α n) (x : α): Vec α (n+1) :=
⟨v.val ++ [x], by simp [v.property]⟩

-- Support array-like access st[i]
@[simp] instance GetElem_Vec : GetElem (Vec α n) Nat α Vec.inBounds where
getElem := Vec.get

-- Extensionality for fixed-length lists
@[ext] def Vec.ext {n : Nat} (v w : Vec α n) (h: ∀(i : Nat), (h : i < n) → v[i] = w[i]) : v = w := by
apply Subtype.eq
apply List.ext_get <;> simp [v.property, w.property]
intros
simp_all [getElem, GetElem_Vec, Vec.get]
done

def Vec.ext' {n : Nat} (v w : Vec α n) : (v=w) <-> (∀(i : Nat), (h : i < n) → v[i] = w[i]) := by
apply Iff.intro
· simp_all
· intros
ext
simp_all
done
182 changes: 182 additions & 0 deletions Specs/AESSpec.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Hanno Becker
-/

import Arm.BitVec
import Arm.Vec

import Lean
open Lean Meta

open BitVec

/-
The goal of this file is to provide a model of the AES specification
that is as close as possible to the original FIPS document.

TODO: This is work in progress. For the moment, we only have definitions
of the AES state as a bitvector, byte vector, or byte array, as well as
conversions between them.
-/

/-
3.3 Indexing of Byte Sequences
-/

abbrev Byte : Type := BitVec 8

def bitvec_split_byte {k : Nat} (invec : BitVec (8*k)) : Byte × (BitVec (8*(k-1))) :=
let byte : BitVec 8 := invec.truncate 8
let remainder := BitVec.extractLsb' 8 (8*(k-1)) invec
(byte, remainder)

theorem bitvec_split_byte_append {k : Nat} (v : BitVec (8*(k-1))) (b : Byte) (h : 8 * (k-1) + 8 = 8 * k) :
bitvec_split_byte (BitVec.cast h (v ++ b)) = (b, v) := by simp [bitvec_split_byte]

theorem bitvec_split_byte_append' {k : Nat} (v : BitVec (8*(k+1))):
let (b, v') := bitvec_split_byte v
v' ++ b = v := by
apply eq_of_getLsb_eq
intro i
simp [getLsb_append]
by_cases h: ((i : Nat) < 8)
· simp [h]
· simp [h]
have lt: (i : Nat) - 8 < 8*k := by simp_all; omega
have e : (8 + ((i : Nat) - 8)) = i := by simp_all
simp [e, lt]
done
done

-- Splitting a bitvector of 8*k entries into bytes, little endian
def bitvec_to_byte_seq (k : Nat) (invec: BitVec (8*k)) : Vec Byte k :=
if k_gt_0: (k > 0) then
let (byte, remainder) := bitvec_split_byte invec
have h: k - 1 + 1 = k := by omega
h ▸ Vec.cons byte (bitvec_to_byte_seq (k-1) remainder)
else
have h: k = 0 := by omega
h ▸ Vec.empty

-- Concatenating a little endian byte sequence into a bitvector
def byte_seq_to_bitvec: Vec Byte k → BitVec (8*k)
| ⟨[], h⟩ =>
let h' : 0 = 8*k := by simp_all; omega
BitVec.cast h' (BitVec.ofNat 0 0)
| ⟨a :: as, h⟩ =>
let g : 8 * (k - 1) + 8 = 8 * k := by simp [←h]; omega
BitVec.cast g (byte_seq_to_bitvec ⟨as, by simp[←h]⟩ ++ a)

@[simp]
theorem byte_seq_to_bitvec_cons:
byte_seq_to_bitvec (Vec.cons byte v) = (byte_seq_to_bitvec v) ++ byte := by
cases v; simp only [Vec.cons, byte_seq_to_bitvec]; rfl

-- Example
def example_bitvec : BitVec 128 := 0x0102030405060708090a0b0c0d0e0f#128
def example_byteseq := bitvec_to_byte_seq 16 example_bitvec
def example_bitvec' := byte_seq_to_bitvec example_byteseq
def example_byteseq' := bitvec_to_byte_seq 16 example_bitvec'

#eval example_bitvec
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd recommend using example to test these out instead of using #eval. E.g., see

example : decode_raw_inst 0x91000421#32 =
.

#eval example_bitvec'
#eval example_byteseq
#eval example_byteseq'

@[simp]
def Vec_0_singleton {k : Nat}: (x y : Vec α k) → k = 0 → x = y
| ⟨[], _⟩, ⟨[], _⟩, _ => by simp
| ⟨_ :: _, hx⟩, _, h => by simp at hx; omega
| _, ⟨_ :: _, hy⟩, h => by simp at hy; omega

theorem bitvec_to_byte_seq_invA: ∀(x : Vec Byte k), bitvec_to_byte_seq k (byte_seq_to_bitvec x) = x
| ⟨[], h⟩ => by
have h: k = 0 := by simp_all
rw [byte_seq_to_bitvec]
rw [bitvec_to_byte_seq]
simp_all [Vec.empty]
apply Vec_0_singleton
assumption
| ⟨x :: xs, h⟩ => by
have k_gt_0: k > 0 := by simp_all; omega
let h' : xs.length = k - 1 := by simp_all; omega
let x' : Vec Byte (k-1) := ⟨xs, h'⟩
let h : bitvec_to_byte_seq (k - 1) (byte_seq_to_bitvec x') = x' :=
bitvec_to_byte_seq_invA x'
rw [byte_seq_to_bitvec, bitvec_to_byte_seq]
simp_all [bitvec_split_byte_append, Vec.cons]
apply Vec.ext''; simp
done

theorem bitvec_to_byte_seq_invB: ∀(k : Nat) (x : BitVec (8*k)), byte_seq_to_bitvec (bitvec_to_byte_seq k x) = x
| 0, x => by simp
| k+1, x => by
rw [bitvec_to_byte_seq]
have kp1_gt_0: k + 1 > 0 := by simp
simp [kp1_gt_0]
let x' := (bitvec_split_byte x).snd
have t' : byte_seq_to_bitvec (bitvec_to_byte_seq (k + 1 - 1) x') = x' := by
exact (bitvec_to_byte_seq_invB k _)
simp only [t', bitvec_split_byte_append' x]

/-
3.4 The state
-/

/- We deal with three different presentations of the AES state:
1/ As a bitvector of length 128
2/ As a byte vector length 16
3/ As a 4x4 grid of bytes
-/

/-- Length 16 vectors <-> 4x4 arrays -/

abbrev AESStateBitVec := BitVec 128
abbrev AESStateVec := Vec Byte 16
abbrev AESStateArr := Vec (Vec Byte 4) 4

def AES_State_BitVec_to_Vec (x: AESStateBitVec): AESStateVec := bitvec_to_byte_seq 16 x

def AESStateVec_to_Arr (x : AESStateVec): AESStateArr :=
⟨[⟨[x[0], x[4], x[8] , x[12]], by simp⟩,
⟨[x[1], x[5], x[9] , x[13]], by simp⟩,
⟨[x[2], x[6], x[10], x[14]], by simp⟩,
⟨[x[3], x[7], x[11], x[15]], by simp⟩], by simp⟩

def AESStateArr_to_Vec (x : AESStateArr) : AESStateVec :=
⟨[ x[0][0], x[1][0], x[2][0], x[3][0],
x[0][1], x[1][1], x[2][1], x[3][1],
x[0][2], x[1][2], x[2][2], x[3][2],
x[0][3], x[1][3], x[2][3], x[3][3] ], by simp⟩

theorem lt_succ_iff_lt_or_eq {n m : Nat} (h : 0 < m) : (n < m) <-> (n < m-1 ∨ n = m-1) := by
have hs: m = Nat.succ (m - 1) := by omega
rw [hs, Nat.lt_succ_iff_lt_or_eq]; simp
done

theorem AESStateVec_to_Arr' (x : AESStateVec) (i j : Nat) (hi: i < 4) (hj: j < 4):
have h: 4*j + i < 16 := by omega
(AESStateVec_to_Arr x)[i][j] = x[4*j + i] := by
simp [lt_succ_iff_lt_or_eq, Nat.not_lt_zero] at hi
simp [lt_succ_iff_lt_or_eq, Nat.not_lt_zero] at hj
elim Or.elim <;> try simp_all <;> simp [AESStateVec_to_Arr, Vec.get, GetElem.getElem, GetElem_Vec, List.get]
done

def AESStateArr_to_Vec_invA (x : AESStateVec): (AESStateArr_to_Vec (AESStateVec_to_Arr x)) = x := by
simp [AESStateArr_to_Vec, AESStateVec_to_Arr']
ext
rename_i i hi
simp [lt_succ_iff_lt_or_eq, Nat.not_lt_zero] at hi
elim Or.elim <;> try simp_all <;> simp [AESStateVec_to_Arr, Vec.get, GetElem.getElem, GetElem_Vec, List.get]
done

def AESStateArr_to_Vec_invB (x : AESStateArr): (AESStateVec_to_Arr (AESStateArr_to_Vec x)) = x := by
simp [AESStateArr_to_Vec, AESStateVec_to_Arr']
ext
rename_i i hi j hj
simp [lt_succ_iff_lt_or_eq, Nat.not_lt_zero] at hi
simp [lt_succ_iff_lt_or_eq, Nat.not_lt_zero] at hj
elim Or.elim <;> try simp_all <;> simp [AESStateVec_to_Arr, Vec.get, GetElem.getElem, GetElem_Vec, List.get]
done
Loading
Loading