Skip to content

Commit

Permalink
feat: verify keys method on HashMaps (#5866)
Browse files Browse the repository at this point in the history
This PR verifies the `keys` function on `Std.HashMap`.

---

Initial discussions have already happend with @TwoFX and we are
collaborating on this matter.
This will remain a draft as long as not all desired results have been
added.

If we should still create an issue for the topic of this PR, let us
know.
Of course, any other feedback is appreciated as well :)

---------

Co-authored-by: Markus Himmel <[email protected]>
Co-authored-by: monsterkrampe <[email protected]>
Co-authored-by: jt0202 <[email protected]>
  • Loading branch information
4 people authored Nov 8, 2024
1 parent 1870c00 commit 9b167e2
Show file tree
Hide file tree
Showing 18 changed files with 270 additions and 20 deletions.
4 changes: 4 additions & 0 deletions src/Init/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2881,6 +2881,10 @@ theorem contains_iff_exists_mem_beq [BEq α] {l : List α} {a : α} :
l.contains a ↔ ∃ a' ∈ l, a == a' := by
induction l <;> simp_all

theorem contains_iff_mem [BEq α] [LawfulBEq α] {l : List α} {a : α} :
l.contains a ↔ a ∈ l := by
simp

/-! ## Sublists -/

/-! ### partition
Expand Down
8 changes: 8 additions & 0 deletions src/Init/Data/List/Perm.lean
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ theorem Perm.length_eq {l₁ l₂ : List α} (p : l₁ ~ l₂) : length l₁ = l
| swap => rfl
| trans _ _ ih₁ ih₂ => simp only [ih₁, ih₂]

theorem Perm.contains_eq [BEq α] {l₁ l₂ : List α} (h : l₁ ~ l₂) {a : α} :
l₁.contains a = l₂.contains a := by
induction h with
| nil => rfl
| cons => simp_all
| swap => simp only [contains_cons, ← Bool.or_assoc, Bool.or_comm]
| trans => simp_all

theorem Perm.eq_nil {l : List α} (p : l ~ []) : l = [] := eq_nil_of_length_eq_zero p.length_eq

theorem Perm.nil_eq {l : List α} (p : [] ~ l) : [] = l := p.symm.eq_nil.symm
Expand Down
6 changes: 3 additions & 3 deletions src/Std/Data/DHashMap/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ end
@[inline, inherit_doc Raw.isEmpty] def isEmpty (m : DHashMap α β) : Bool :=
m.1.isEmpty

@[inline, inherit_doc Raw.keys] def keys (m : DHashMap α β) : List α :=
m.1.keys

section Unverified

/-! We currently do not provide lemmas for the functions below. -/
Expand Down Expand Up @@ -232,9 +235,6 @@ instance [BEq α] [Hashable α] : ForIn m (DHashMap α β) ((a : α) × β a) wh
(m : DHashMap α (fun _ => β)) : Array (α × β) :=
Raw.Const.toArray m.1

@[inline, inherit_doc Raw.keys] def keys (m : DHashMap α β) : List α :=
m.1.keys

@[inline, inherit_doc Raw.keysArray] def keysArray (m : DHashMap α β) :
Array α :=
m.1.keysArray
Expand Down
5 changes: 5 additions & 0 deletions src/Std/Data/DHashMap/Internal/AssocList/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,9 @@ theorem toList_filter {f : (a : α) → β a → Bool} {l : AssocList α β} :
· exact (ih _).trans (by simpa using perm_middle.symm)
· exact ih _

theorem foldl_apply {l : AssocList α β} {acc : List δ} (f : (a : α) → β a → δ) :
l.foldl (fun acc k v => f k v :: acc) acc =
(l.toList.map (fun p => f p.1 p.2)).reverse ++ acc := by
induction l generalizing acc <;> simp_all [AssocList.foldl, AssocList.foldlM, Id.run]

end Std.DHashMap.Internal.AssocList
6 changes: 6 additions & 0 deletions src/Std/Data/DHashMap/Internal/List/Associative.lean
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,12 @@ theorem isEmpty_eraseKey [BEq α] {l : List ((a : α) × β a)} {k : α} :
theorem keys_eq_map (l : List ((a : α) × β a)) : keys l = l.map (·.1) := by
induction l using assoc_induction <;> simp_all

theorem length_keys_eq_length (l : List ((a : α) × β a)) : (keys l).length = l.length := by
induction l using assoc_induction <;> simp_all

theorem isEmpty_keys_eq_isEmpty (l : List ((a : α) × β a)) : (keys l).isEmpty = l.isEmpty := by
induction l using assoc_induction <;> simp_all

theorem containsKey_eq_keys_contains [BEq α] [PartialEquivBEq α] {l : List ((a : α) × β a)}
{a : α} : containsKey a l = (keys l).contains a := by
induction l using assoc_induction
Expand Down
30 changes: 29 additions & 1 deletion src/Std/Data/DHashMap/Internal/RawLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ private def queryNames : Array Name :=
``get?_eq_getValueCast?, ``Const.get?_eq_getValue?, ``get_eq_getValueCast,
``Const.get_eq_getValue, ``get!_eq_getValueCast!, ``getD_eq_getValueCastD,
``Const.get!_eq_getValue!, ``Const.getD_eq_getValueD, ``getKey?_eq_getKey?,
``getKey_eq_getKey, ``getKeyD_eq_getKeyD, ``getKey!_eq_getKey!]
``getKey_eq_getKey, ``getKeyD_eq_getKeyD, ``getKey!_eq_getKey!,
``Raw.length_keys_eq_length_keys, ``Raw.isEmpty_keys_eq_isEmpty_keys,
``Raw.contains_keys_eq_contains_keys, ``Raw.mem_keys_iff_contains_keys,
``Raw.pairwise_keys_iff_pairwise_keys]

private def modifyNames : Array Name :=
#[``toListModel_insert, ``toListModel_erase, ``toListModel_insertIfNew]
Expand Down Expand Up @@ -811,6 +814,31 @@ theorem getThenInsertIfNew?_snd {k : α} {v : β} :

end Const

@[simp]
theorem length_keys [EquivBEq α] [LawfulHashable α] (h : m.1.WF) :
m.1.keys.length = m.1.size := by
simp_to_model using List.length_keys_eq_length

@[simp]
theorem isEmpty_keys [EquivBEq α] [LawfulHashable α] (h: m.1.WF):
m.1.keys.isEmpty = m.1.isEmpty:= by
simp_to_model using List.isEmpty_keys_eq_isEmpty

@[simp]
theorem contains_keys [EquivBEq α] [LawfulHashable α] (h : m.1.WF) {k : α} :
m.1.keys.contains k = m.contains k := by
simp_to_model using List.containsKey_eq_keys_contains.symm

@[simp]
theorem mem_keys [LawfulBEq α] [LawfulHashable α] (h : m.1.WF) {k : α} :
k ∈ m.1.keys ↔ m.contains k := by
simp_to_model
rw [List.containsKey_eq_keys_contains]

theorem distinct_keys [EquivBEq α] [LawfulHashable α] (h : m.1.WF) :
m.1.keys.Pairwise (fun a b => (a == b) = false) := by
simp_to_model using (Raw.WF.out h).distinct.distinct

end Raw₀

end Std.DHashMap.Internal
52 changes: 52 additions & 0 deletions src/Std/Data/DHashMap/Internal/WF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,58 @@ theorem isEmpty_eq_isEmpty [BEq α] [Hashable α] {m : Raw α β} (h : Raw.WFImp
rw [Raw.isEmpty, Bool.eq_iff_iff, List.isEmpty_iff_length_eq_zero, size_eq_length h,
Nat.beq_eq_true_eq]

theorem fold_eq {l : Raw α β} {f : γ → (a : α) → β a → γ} {init : γ} :
l.fold f init = l.buckets.foldl (fun acc l => l.foldl f acc) init := by
simp only [Raw.fold, Raw.foldM, Array.foldlM_eq_foldlM_toList, Array.foldl_eq_foldl_toList,
← List.foldl_eq_foldlM, Id.run, AssocList.foldl]

theorem fold_cons_apply {l : Raw α β} {acc : List γ} (f : (a : α) → β a → γ) :
l.fold (fun acc k v => f k v :: acc) acc =
((toListModel l.buckets).reverse.map (fun p => f p.1 p.2)) ++ acc := by
rw [fold_eq, Array.foldl_eq_foldl_toList, toListModel]
generalize l.buckets.toList = l
induction l generalizing acc with
| nil => simp
| cons x xs ih =>
rw [foldl_cons, ih, AssocList.foldl_apply]
simp

theorem fold_cons {l : Raw α β} {acc : List ((a : α) × β a)} :
l.fold (fun acc k v => ⟨k, v⟩ :: acc) acc = (toListModel l.buckets).reverse ++ acc := by
simp [fold_cons_apply]

theorem fold_cons_key {l : Raw α β} {acc : List α} :
l.fold (fun acc k _ => k :: acc) acc = List.keys (toListModel l.buckets).reverse ++ acc := by
rw [fold_cons_apply, keys_eq_map, map_reverse]

theorem toList_perm_toListModel {m : Raw α β} : Perm m.toList (toListModel m.buckets) := by
simp [Raw.toList, fold_cons]

theorem keys_perm_keys_toListModel {m : Raw α β} :
Perm m.keys (List.keys (toListModel m.buckets)) := by
simp [Raw.keys, fold_cons_key, keys_eq_map]

theorem length_keys_eq_length_keys {m : Raw α β} :
m.keys.length = (List.keys (toListModel m.buckets)).length :=
keys_perm_keys_toListModel.length_eq

theorem isEmpty_keys_eq_isEmpty_keys {m : Raw α β} :
m.keys.isEmpty = (List.keys (toListModel m.buckets)).isEmpty :=
keys_perm_keys_toListModel.isEmpty_eq

theorem contains_keys_eq_contains_keys [BEq α] {m : Raw α β} {k : α} :
m.keys.contains k = (List.keys (toListModel m.buckets)).contains k :=
keys_perm_keys_toListModel.contains_eq

theorem mem_keys_iff_contains_keys [BEq α] [LawfulBEq α] {m : Raw α β} {k : α} :
k ∈ m.keys ↔ (List.keys (toListModel m.buckets)).contains k := by
rw [← List.contains_iff_mem, contains_keys_eq_contains_keys]

theorem pairwise_keys_iff_pairwise_keys [BEq α] [PartialEquivBEq α] {m : Raw α β} :
m.keys.Pairwise (fun a b => (a == b) = false) ↔
(List.keys (toListModel m.buckets)).Pairwise (fun a b => (a == b) = false) :=
keys_perm_keys_toListModel.pairwise_iff BEq.symm_false

end Raw

namespace Raw₀
Expand Down
25 changes: 25 additions & 0 deletions src/Std/Data/DHashMap/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -943,4 +943,29 @@ theorem getThenInsertIfNew?_snd {k : α} {v : β} :

end Const

@[simp]
theorem length_keys [EquivBEq α] [LawfulHashable α] :
m.keys.length = m.size :=
Raw₀.length_keys ⟨m.1, m.2.size_buckets_pos⟩ m.2

@[simp]
theorem isEmpty_keys [EquivBEq α] [LawfulHashable α]:
m.keys.isEmpty = m.isEmpty :=
Raw₀.isEmpty_keys ⟨m.1, m.2.size_buckets_pos⟩ m.2

@[simp]
theorem contains_keys [EquivBEq α] [LawfulHashable α] {k : α} :
m.keys.contains k = m.contains k :=
Raw₀.contains_keys ⟨m.1, _⟩ m.2

@[simp]
theorem mem_keys [LawfulBEq α] [LawfulHashable α] {k : α} :
k ∈ m.keys ↔ k ∈ m := by
rw [mem_iff_contains]
exact Raw₀.mem_keys ⟨m.1, _⟩ m.2

theorem distinct_keys [EquivBEq α] [LawfulHashable α] :
m.keys.Pairwise (fun a b => (a == b) = false) :=
Raw₀.distinct_keys ⟨m.1, m.2.size_buckets_pos⟩ m.2

end Std.DHashMap
8 changes: 4 additions & 4 deletions src/Std/Data/DHashMap/Raw.lean
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,6 @@ instance : ForIn m (Raw α β) ((a : α) × β a) where
Array (α × β) :=
m.fold (fun acc k v => acc.push ⟨k, v⟩) #[]

/-- Returns a list of all keys present in the hash map in some order. -/
@[inline] def keys (m : Raw α β) : List α :=
m.fold (fun acc k _ => k :: acc) []

/-- Returns an array of all keys present in the hash map in some order. -/
@[inline] def keysArray (m : Raw α β) : Array α :=
m.fold (fun acc k _ => acc.push k) #[]
Expand Down Expand Up @@ -447,6 +443,10 @@ instance [Repr α] [(a : α) → Repr (β a)] : Repr (Raw α β) where

end Unverified

/-- Returns a list of all keys present in the hash map in some order. -/
@[inline] def keys (m : Raw α β) : List α :=
m.fold (fun acc k _ => k :: acc) []

section WF

/--
Expand Down
25 changes: 25 additions & 0 deletions src/Std/Data/DHashMap/RawLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,31 @@ theorem getThenInsertIfNew?_snd (h : m.WF) {k : α} {v : β} :

end Const

@[simp]
theorem length_keys [EquivBEq α] [LawfulHashable α] (h : m.WF) :
m.keys.length = m.size := by
simp_to_raw using Raw₀.length_keys ⟨m, h.size_buckets_pos⟩ h

@[simp]
theorem isEmpty_keys [EquivBEq α] [LawfulHashable α] (h : m.WF):
m.keys.isEmpty = m.isEmpty := by
simp_to_raw using Raw₀.isEmpty_keys ⟨m, h.size_buckets_pos⟩

@[simp]
theorem contains_keys [EquivBEq α] [LawfulHashable α] (h : m.WF) {k : α} :
m.keys.contains k = m.contains k := by
simp_to_raw using Raw₀.contains_keys ⟨m, _⟩ h

@[simp]
theorem mem_keys [LawfulBEq α] [LawfulHashable α] (h : m.WF) {k : α} :
k ∈ m.keys ↔ k ∈ m := by
rw [mem_iff_contains]
simp_to_raw using Raw₀.mem_keys ⟨m, _⟩ h

theorem distinct_keys [EquivBEq α] [LawfulHashable α] (h : m.WF) :
m.keys.Pairwise (fun a b => (a == b) = false) := by
simp_to_raw using Raw₀.distinct_keys ⟨m, h.size_buckets_pos⟩ h

end Raw

end Std.DHashMap
6 changes: 3 additions & 3 deletions src/Std/Data/HashMap/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ instance [BEq α] [Hashable α] : GetElem? (HashMap α β) α β (fun m a => a
@[inline, inherit_doc DHashMap.isEmpty] def isEmpty (m : HashMap α β) : Bool :=
m.inner.isEmpty

@[inline, inherit_doc DHashMap.keys] def keys (m : HashMap α β) : List α :=
m.inner.keys

section Unverified

/-! We currently do not provide lemmas for the functions below. -/
Expand Down Expand Up @@ -231,9 +234,6 @@ instance [BEq α] [Hashable α] {m : Type w → Type w} : ForIn m (HashMap α β
Array (α × β) :=
DHashMap.Const.toArray m.inner

@[inline, inherit_doc DHashMap.keys] def keys (m : HashMap α β) : List α :=
m.inner.keys

@[inline, inherit_doc DHashMap.keysArray] def keysArray (m : HashMap α β) :
Array α :=
m.inner.keysArray
Expand Down
24 changes: 24 additions & 0 deletions src/Std/Data/HashMap/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,30 @@ instance [EquivBEq α] [LawfulHashable α] : LawfulGetElem (HashMap α β) α β
rw [getElem!_eq_get!_getElem?]
split <;> simp_all

@[simp]
theorem length_keys [EquivBEq α] [LawfulHashable α] :
m.keys.length = m.size :=
DHashMap.length_keys

@[simp]
theorem isEmpty_keys [EquivBEq α] [LawfulHashable α]:
m.keys.isEmpty = m.isEmpty :=
DHashMap.isEmpty_keys

@[simp]
theorem contains_keys [EquivBEq α] [LawfulHashable α] {k : α} :
m.keys.contains k = m.contains k :=
DHashMap.contains_keys

@[simp]
theorem mem_keys [LawfulBEq α] [LawfulHashable α] {k : α} :
k ∈ m.keys ↔ k ∈ m :=
DHashMap.mem_keys

theorem distinct_keys [EquivBEq α] [LawfulHashable α] :
m.keys.Pairwise (fun a b => (a == b) = false) :=
DHashMap.distinct_keys

end

end Std.HashMap
6 changes: 3 additions & 3 deletions src/Std/Data/HashMap/Raw.lean
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ instance [BEq α] [Hashable α] : GetElem? (Raw α β) α β (fun m a => a ∈ m
@[inline, inherit_doc DHashMap.Raw.isEmpty] def isEmpty (m : Raw α β) : Bool :=
m.inner.isEmpty

@[inline, inherit_doc DHashMap.Raw.keys] def keys (m : Raw α β) : List α :=
m.inner.keys

section Unverified

/-! We currently do not provide lemmas for the functions below. -/
Expand Down Expand Up @@ -213,9 +216,6 @@ instance {m : Type w → Type w} : ForIn m (Raw α β) (α × β) where
@[inline, inherit_doc DHashMap.Raw.Const.toArray] def toArray (m : Raw α β) : Array (α × β) :=
DHashMap.Raw.Const.toArray m.inner

@[inline, inherit_doc DHashMap.Raw.keys] def keys (m : Raw α β) : List α :=
m.inner.keys

@[inline, inherit_doc DHashMap.Raw.keysArray] def keysArray (m : Raw α β) : Array α :=
m.inner.keysArray

Expand Down
24 changes: 24 additions & 0 deletions src/Std/Data/HashMap/RawLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,30 @@ theorem getThenInsertIfNew?_snd (h : m.WF) {k : α} {v : β} :
(getThenInsertIfNew? m k v).2 = m.insertIfNew k v :=
ext (DHashMap.Raw.Const.getThenInsertIfNew?_snd h.out)

@[simp]
theorem length_keys [EquivBEq α] [LawfulHashable α] (h : m.WF) :
m.keys.length = m.size :=
DHashMap.Raw.length_keys h.out

@[simp]
theorem isEmpty_keys [EquivBEq α] [LawfulHashable α] (h : m.WF):
m.keys.isEmpty = m.isEmpty :=
DHashMap.Raw.isEmpty_keys h.out

@[simp]
theorem contains_keys [EquivBEq α] [LawfulHashable α] (h : m.WF) {k : α} :
m.keys.contains k = m.contains k :=
DHashMap.Raw.contains_keys h.out

@[simp]
theorem mem_keys [LawfulBEq α] [LawfulHashable α] (h : m.WF) {k : α} :
k ∈ m.keys ↔ k ∈ m :=
DHashMap.Raw.mem_keys h.out

theorem distinct_keys [EquivBEq α] [LawfulHashable α] (h : m.WF) :
m.keys.Pairwise (fun a b => (a == b) = false) :=
DHashMap.Raw.distinct_keys h.out

end Raw

end Std.HashMap
7 changes: 4 additions & 3 deletions src/Std/Data/HashSet/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ for all `a`.
@[inline] def isEmpty (m : HashSet α) : Bool :=
m.inner.isEmpty

/-- Transforms the hash set into a list of elements in some order. -/
@[inline] def toList (m : HashSet α) : List α :=
m.inner.keys

section Unverified

/-! We currently do not provide lemmas for the functions below. -/
Expand Down Expand Up @@ -208,9 +212,6 @@ instance [BEq α] [Hashable α] {m : Type v → Type v} : ForIn m (HashSet α)
if p a then return true
return false

/-- Transforms the hash set into a list of elements in some order. -/
@[inline] def toList (m : HashSet α) : List α :=
m.inner.keys

/-- Transforms the hash set into an array of elements in some order. -/
@[inline] def toArray (m : HashSet α) : Array α :=
Expand Down
Loading

0 comments on commit 9b167e2

Please sign in to comment.