Skip to content

Commit

Permalink
feat: alignment of Array.set lemmas with List lemmas (#6365)
Browse files Browse the repository at this point in the history
This PR expands the `Array.set` and `Array.setIfInBounds` lemmas to
match existing lemmas for `List.set`.
  • Loading branch information
kim-em authored Dec 11, 2024
1 parent cd909b0 commit c83ce02
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 93 deletions.
268 changes: 175 additions & 93 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import Init.Data.List.ToArray
## Theorems about `Array`.
-/



namespace Array

/-! ## Preliminaries -/
Expand Down Expand Up @@ -184,6 +182,14 @@ theorem getElem_eq_getElem?_get (a : Array α) (i : Nat) (h : i < a.size) :
a[i] = a[i]?.get (by simp [getElem?_eq_getElem, h]) := by
simp [getElem_eq_iff]

theorem getD_getElem? (a : Array α) (i : Nat) (d : α) :
a[i]?.getD d = if p : i < a.size then a[i]'p else d := by
if h : i < a.size then
simp [h, getElem?_def]
else
have p : i ≥ a.size := Nat.le_of_not_gt h
simp [getElem?_eq_none p, h]

@[simp] theorem getElem?_empty {n : Nat} : (#[] : Array α)[n]? = none := rfl

theorem getElem_push_lt (a : Array α) (x : α) (i : Nat) (h : i < a.size) :
Expand Down Expand Up @@ -710,6 +716,164 @@ theorem all_push [BEq α] {as : Array α} {a : α} {p : α → Bool} :
(l.push a).contains b = (l.contains b || b == a) := by
simp [contains]

/-! ### set -/


/-! # set -/

@[simp] theorem getElem_set_self (a : Array α) (i : Nat) (h : i < a.size) (v : α) {j : Nat}
(eq : i = j) (p : j < (a.set i v).size) :
(a.set i v)[j]'p = v := by
cases a
simp
simp [set, ← getElem_toList, ←eq]

@[deprecated getElem_set_self (since := "2024-12-11")]
abbrev getElem_set_eq := @getElem_set_self

@[simp] theorem getElem?_set_self (a : Array α) (i : Nat) (h : i < a.size) (v : α) :
(a.set i v)[i]? = v := by simp [getElem?_eq_getElem, h]

@[deprecated getElem?_set_self (since := "2024-12-11")]
abbrev getElem?_set_eq := @getElem?_set_self

@[simp] theorem getElem_set_ne (a : Array α) (i : Nat) (h' : i < a.size) (v : α) {j : Nat}
(pj : j < (a.set i v).size) (h : i ≠ j) :
(a.set i v)[j]'pj = a[j]'(size_set a i v _ ▸ pj) := by
simp only [set, ← getElem_toList, List.getElem_set_ne h]

@[simp] theorem getElem?_set_ne (a : Array α) (i : Nat) (h : i < a.size) {j : Nat} (v : α)
(ne : i ≠ j) : (a.set i v)[j]? = a[j]? := by
by_cases h : j < a.size <;> simp [getElem?_eq_getElem, getElem?_eq_none, Nat.ge_of_not_lt, ne, h]

theorem getElem_set (a : Array α) (i : Nat) (h' : i < a.size) (v : α) (j : Nat)
(h : j < (a.set i v).size) :
(a.set i v)[j]'h = if i = j then v else a[j]'(size_set a i v _ ▸ h) := by
by_cases p : i = j <;> simp [p]

theorem getElem?_set (a : Array α) (i : Nat) (h : i < a.size) (v : α) (j : Nat) :
(a.set i v)[j]? = if i = j then some v else a[j]? := by
split <;> simp_all

@[simp] theorem set_getElem_self {as : Array α} {i : Nat} (h : i < as.size) :
as.set i as[i] = as := by
cases as
simp

@[simp] theorem set_eq_empty_iff {as : Array α} (n : Nat) (a : α) (h) :
as.set n a = #[] ↔ as = #[] := by
cases as <;> cases n <;> simp [set]

theorem set_comm (a b : α)
{n m : Nat} (as : Array α) {hn : n < as.size} {hm : m < (as.set n a).size} (h : n ≠ m) :
(as.set n a).set m b = (as.set m b (by simpa using hm)).set n a (by simpa using hn) := by
cases as
simp [List.set_comm _ _ _ h]

@[simp]
theorem set_set (a b : α) (as : Array α) (n : Nat) (h : n < as.size) :
(as.set n a).set n b (by simpa using h) = as.set n b := by
cases as
simp

theorem mem_set (as : Array α) (n : Nat) (h : n < as.size) (a : α) :
a ∈ as.set n a := by
simp [mem_iff_getElem]
exact ⟨n, (by simpa using h), by simp⟩

theorem mem_or_eq_of_mem_set
{as : Array α} {n : Nat} {a b : α} {w : n < as.size} (h : a ∈ as.set n b) : a ∈ as ∨ a = b := by
cases as
simpa using List.mem_or_eq_of_mem_set (by simpa using h)

/-! # setIfInBounds -/

@[simp] theorem set!_is_setIfInBounds : @set! = @setIfInBounds := rfl

@[simp] theorem size_setIfInBounds (as : Array α) (index : Nat) (val : α) :
(as.setIfInBounds index val).size = as.size := by
if h : index < as.size then
simp [setIfInBounds, h]
else
simp [setIfInBounds, h]

theorem getElem_setIfInBounds (as : Array α) (i : Nat) (v : α) (j : Nat)
(hj : j < (as.setIfInBounds i v).size) :
(as.setIfInBounds i v)[j]'hj = if i = j then v else as[j]'(by simpa using hj) := by
simp only [setIfInBounds]
split
· simp [getElem_set]
· simp only [size_setIfInBounds] at hj
rw [if_neg]
omega

@[simp] theorem getElem_setIfInBounds_self (as : Array α) {i : Nat} (v : α) (h : _) :
(as.setIfInBounds i v)[i]'h = v := by
simp at h
simp only [setIfInBounds, h, ↓reduceDIte, getElem_set_self]

@[deprecated getElem_setIfInBounds_self (since := "2024-12-11")]
abbrev getElem_setIfInBounds_eq := @getElem_setIfInBounds_self

@[simp] theorem getElem_setIfInBounds_ne (as : Array α) {i : Nat} (v : α) {j : Nat}
(hj : j < (as.setIfInBounds i v).size) (h : i ≠ j) :
(as.setIfInBounds i v)[j]'hj = as[j]'(by simpa using hj) := by
simp [getElem_setIfInBounds, h]

@[simp]
theorem getElem?_setIfInBounds_self (as : Array α) {i : Nat} (p : i < as.size) (v : α) :
(as.setIfInBounds i v)[i]? = some v := by
simp [getElem?_eq_getElem, p]

@[deprecated getElem?_setIfInBounds_self (since := "2024-12-11")]
abbrev getElem?_setIfInBounds_eq := @getElem?_setIfInBounds_self

theorem getElem?_setIfInBounds {as : Array α} {i j : Nat} {a : α} :
(as.setIfInBounds i a)[j]? = if i = j then if i < as.size then some a else none else as[j]? := by
cases as
simp [List.getElem?_set]

@[simp] theorem getElem?_setIfInBounds_ne {as : Array α} {i j : Nat} (h : i ≠ j) {a : α} :
(as.setIfInBounds i a)[j]? = as[j]? := by
simp [getElem?_setIfInBounds, h]

theorem setIfInBounds_eq_of_size_le {l : Array α} {n : Nat} (h : l.size ≤ n) {a : α} :
l.setIfInBounds n a = l := by
cases l
simp [List.set_eq_of_length_le (by simpa using h)]

@[simp] theorem setIfInBounds_eq_empty_iff {as : Array α} (n : Nat) (a : α) :
as.setIfInBounds n a = #[] ↔ as = #[] := by
cases as <;> cases n <;> simp

theorem setIfInBounds_comm (a b : α)
{n m : Nat} (as : Array α) (h : n ≠ m) :
(as.setIfInBounds n a).setIfInBounds m b = (as.setIfInBounds m b).setIfInBounds n a := by
cases as
simp [List.set_comm _ _ _ h]

@[simp]
theorem setIfInBounds_setIfInBounds (a b : α) (as : Array α) (n : Nat) :
(as.setIfInBounds n a).setIfInBounds n b = as.setIfInBounds n b := by
cases as
simp

theorem mem_setIfInBounds (as : Array α) (n : Nat) (h : n < as.size) (a : α) :
a ∈ as.setIfInBounds n a := by
simp [mem_iff_getElem]
exact ⟨n, (by simpa using h), by simp⟩

theorem mem_or_eq_of_mem_setIfInBounds
{as : Array α} {n : Nat} {a b : α} (h : a ∈ as.setIfInBounds n b) : a ∈ as ∨ a = b := by
cases as
simpa using List.mem_or_eq_of_mem_set (by simpa using h)

/-- Simplifies a normal form from `get!` -/
@[simp] theorem getD_get?_setIfInBounds (a : Array α) (i : Nat) (v d : α) :
(setIfInBounds a i v)[i]?.getD d = if i < a.size then v else d := by
by_cases h : i < a.size <;>
simp [setIfInBounds, Nat.not_lt_of_le, h, getD_getElem?]

theorem singleton_inj : #[a] = #[b] ↔ a = b := by
simp

Expand Down Expand Up @@ -820,102 +984,31 @@ theorem size_uset (a : Array α) (v i h) : (uset a i v h).size = a.size := by si

/-! # get -/

@[deprecated getElem?_eq_getElem (since := "2024-12-11")]
theorem getElem?_lt
(a : Array α) {i : Nat} (h : i < a.size) : a[i]? = some a[i] := dif_pos h

@[deprecated getElem?_eq_none (since := "2024-12-11")]
theorem getElem?_ge
(a : Array α) {i : Nat} (h : i ≥ a.size) : a[i]? = none := dif_neg (Nat.not_lt_of_le h)

@[simp] theorem get?_eq_getElem? (a : Array α) (i : Nat) : a.get? i = a[i]? := rfl

@[deprecated getElem?_eq_none (since := "2024-12-11")]
theorem getElem?_len_le (a : Array α) {i : Nat} (h : a.size ≤ i) : a[i]? = none := by
simp [getElem?_ge, h]
simp [getElem?_eq_none, h]

theorem getD_get? (a : Array α) (i : Nat) (d : α) :
Option.getD a[i]? d = if p : i < a.size then a[i]'p else d := by
if h : i < a.size then
simp [setIfInBounds, h, getElem?_def]
else
have p : i ≥ a.size := Nat.le_of_not_gt h
simp [setIfInBounds, getElem?_len_le _ p, h]
@[deprecated getD_getElem? (since := "2024-12-11")] abbrev getD_get? := @getD_getElem?

@[simp] theorem getD_eq_get? (a : Array α) (n d) : a.getD n d = (a[n]?).getD d := by
simp only [getD, get_eq_getElem, get?_eq_getElem?]; split <;> simp [getD_get?, *]
simp only [getD, get_eq_getElem, get?_eq_getElem?]; split <;> simp [getD_getElem?, *]

theorem get!_eq_getD [Inhabited α] (a : Array α) : a.get! n = a.getD n default := rfl

@[simp] theorem get!_eq_getElem? [Inhabited α] (a : Array α) (i : Nat) :
a.get! i = (a.get? i).getD default := by
by_cases p : i < a.size <;>
simp only [get!_eq_getD, getD_eq_get?, getD_get?, p, get?_eq_getElem?]

/-! # set -/

@[simp] theorem getElem_set_eq (a : Array α) (i : Nat) (h : i < a.size) (v : α) {j : Nat}
(eq : i = j) (p : j < (a.set i v).size) :
(a.set i v)[j]'p = v := by
cases a
simp
simp [set, ← getElem_toList, ←eq]

@[simp] theorem getElem_set_ne (a : Array α) (i : Nat) (h' : i < a.size) (v : α) {j : Nat}
(pj : j < (a.set i v).size) (h : i ≠ j) :
(a.set i v)[j]'pj = a[j]'(size_set a i v _ ▸ pj) := by
simp only [set, ← getElem_toList, List.getElem_set_ne h]

theorem getElem_set (a : Array α) (i : Nat) (h' : i < a.size) (v : α) (j : Nat)
(h : j < (a.set i v).size) :
(a.set i v)[j]'h = if i = j then v else a[j]'(size_set a i v _ ▸ h) := by
by_cases p : i = j <;> simp [p]

@[simp] theorem getElem?_set_eq (a : Array α) (i : Nat) (h : i < a.size) (v : α) :
(a.set i v)[i]? = v := by simp [getElem?_lt, h]

@[simp] theorem getElem?_set_ne (a : Array α) (i : Nat) (h : i < a.size) {j : Nat} (v : α)
(ne : i ≠ j) : (a.set i v)[j]? = a[j]? := by
by_cases h : j < a.size <;> simp [getElem?_lt, getElem?_ge, Nat.ge_of_not_lt, ne, h]

/-! # setIfInBounds -/

@[simp] theorem set!_is_setIfInBounds : @set! = @setIfInBounds := rfl

@[simp] theorem size_setIfInBounds (a : Array α) (index : Nat) (val : α) :
(Array.setIfInBounds a index val).size = a.size := by
if h : index < a.size then
simp [setIfInBounds, h]
else
simp [setIfInBounds, h]

theorem getElem_setIfInBounds (a : Array α) (i : Nat) (v : α) (j : Nat)
(hj : j < (setIfInBounds a i v).size) :
(setIfInBounds a i v)[j]'hj = if i = j then v else a[j]'(by simpa using hj) := by
simp only [setIfInBounds]
split
· simp [getElem_set]
· simp only [size_setIfInBounds] at hj
rw [if_neg]
omega

@[simp] theorem getElem_setIfInBounds_eq (a : Array α) {i : Nat} (v : α) (h : _) :
(setIfInBounds a i v)[i]'h = v := by
simp at h
simp only [setIfInBounds, h, ↓reduceDIte, getElem_set_eq]

@[simp] theorem getElem_setIfInBounds_ne (a : Array α) {i : Nat} (v : α) {j : Nat}
(hj : j < (setIfInBounds a i v).size) (h : i ≠ j) :
(setIfInBounds a i v)[j]'hj = a[j]'(by simpa using hj) := by
simp [getElem_setIfInBounds, h]

@[simp]
theorem getElem?_setIfInBounds_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) :
(a.setIfInBounds i v)[i]? = some v := by
simp [getElem?_lt, p]

/-- Simplifies a normal form from `get!` -/
@[simp] theorem getD_get?_setIfInBounds (a : Array α) (i : Nat) (v d : α) :
Option.getD (setIfInBounds a i v)[i]? d = if i < a.size then v else d := by
by_cases h : i < a.size <;>
simp [setIfInBounds, Nat.not_lt_of_le, h, getD_get?]
simp only [get!_eq_getD, getD_eq_get?, getD_getElem?, p, get?_eq_getElem?]

/-! # ofFn -/

Expand Down Expand Up @@ -1092,9 +1185,6 @@ theorem get_set (a : Array α) (i : Nat) (hi : i < a.size) (j : Nat) (hj : j < a
(h : i ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by
simp only [set, ← getElem_toList, List.getElem_set_ne h]

theorem set_set (a : Array α) (i : Nat) (h) (v v' : α) :
(a.set i v h).set i v' (by simp [h]) = a.set i v' := by simp [set, List.set_set]

private theorem fin_cast_val (e : n = n') (i : Fin n) : e ▸ i = ⟨i.1, e ▸ i.2⟩ := by cases e; rfl

theorem swap_def (a : Array α) (i j : Nat) (hi hj) :
Expand Down Expand Up @@ -1998,14 +2088,6 @@ Our goal is to have `simp` "pull `List.toArray` outwards" as much as possible.
apply ext'
simp

@[simp] theorem setIfInBounds_toArray (l : List α) (i : Nat) (a : α) :
l.toArray.setIfInBounds i a = (l.set i a).toArray := by
apply ext'
simp only [setIfInBounds]
split
· simp
· simp_all [List.set_eq_of_length_le]

@[simp] theorem swap_toArray (l : List α) (i j : Nat) {hi hj}:
l.toArray.swap i j hi hj = ((l.set i l[j]).set j l[i]).toArray := by
apply ext'
Expand Down Expand Up @@ -2324,8 +2406,8 @@ abbrev eq_push_pop_back_of_size_ne_zero := @eq_push_pop_back!_of_size_ne_zero

@[deprecated set!_is_setIfInBounds (since := "2024-11-24")] abbrev set_is_setIfInBounds := @set!_is_setIfInBounds
@[deprecated size_setIfInBounds (since := "2024-11-24")] abbrev size_setD := @size_setIfInBounds
@[deprecated getElem_setIfInBounds_eq (since := "2024-11-24")] abbrev getElem_setD_eq := @getElem_setIfInBounds_eq
@[deprecated getElem?_setIfInBounds_eq (since := "2024-11-24")] abbrev get?_setD_eq := @getElem?_setIfInBounds_eq
@[deprecated getElem_setIfInBounds_eq (since := "2024-11-24")] abbrev getElem_setD_eq := @getElem_setIfInBounds_self
@[deprecated getElem?_setIfInBounds_eq (since := "2024-11-24")] abbrev get?_setD_eq := @getElem?_setIfInBounds_self
@[deprecated getD_get?_setIfInBounds (since := "2024-11-24")] abbrev getD_setD := @getD_get?_setIfInBounds
@[deprecated getElem_setIfInBounds (since := "2024-11-24")] abbrev getElem_setD := @getElem_setIfInBounds

Expand Down
8 changes: 8 additions & 0 deletions src/Init/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ theorem getElem_eq_getElem?_get (l : List α) (i : Nat) (h : i < l.length) :
l[i] = l[i]?.get (by simp [getElem?_eq_getElem, h]) := by
simp [getElem_eq_iff]

theorem getD_getElem? (l : List α) (i : Nat) (d : α) :
l[i]?.getD d = if p : i < l.length then l[i]'p else d := by
if h : i < l.length then
simp [h, getElem?_def]
else
have p : i ≥ l.length := Nat.le_of_not_gt h
simp [getElem?_eq_none p, h]

@[simp] theorem getElem?_nil {n : Nat} : ([] : List α)[n]? = none := rfl

theorem getElem_cons {l : List α} (w : i < (a :: l).length) :
Expand Down
8 changes: 8 additions & 0 deletions src/Init/Data/List/ToArray.lean
Original file line number Diff line number Diff line change
Expand Up @@ -363,4 +363,12 @@ theorem takeWhile_go_toArray (p : α → Bool) (l : List α) (i : Nat) :
l.toArray.takeWhile p = (l.takeWhile p).toArray := by
simp [Array.takeWhile, takeWhile_go_toArray]

@[simp] theorem setIfInBounds_toArray (l : List α) (i : Nat) (a : α) :
l.toArray.setIfInBounds i a = (l.set i a).toArray := by
apply ext'
simp only [setIfInBounds]
split
· simp
· simp_all [List.set_eq_of_length_le]

end List

0 comments on commit c83ce02

Please sign in to comment.