Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: remove @[simp] from Fin.succ_zero_eq_one #6292

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/Init/Data/Fin/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,9 @@ theorem zero_ne_one : (0 : Fin (n + 2)) ≠ 1 := Fin.ne_of_lt one_pos
theorem succ_ne_zero {n} : ∀ k : Fin n, Fin.succ k ≠ 0
| ⟨k, _⟩, heq => Nat.succ_ne_zero k <| congrArg Fin.val heq

@[simp] theorem succ_zero_eq_one : Fin.succ (0 : Fin (n + 1)) = 1 := rfl
theorem succ_zero_eq_one : Fin.succ (0 : Fin (n + 1)) = 1 := rfl

/-- Version of `succ_one_eq_two` to be used by `dsimp` -/
@[simp] theorem succ_one_eq_two : Fin.succ (1 : Fin (n + 2)) = 2 := rfl
theorem succ_one_eq_two : Fin.succ (1 : Fin (n + 2)) = 2 := rfl

@[simp] theorem succ_mk (n i : Nat) (h : i < n) :
Fin.succ ⟨i, h⟩ = ⟨i + 1, Nat.succ_lt_succ h⟩ := rfl
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/LitValues.lean
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def getStringValue? (e : Expr) : (Option String) :=
| .lit (.strVal s) => some s
| _ => none

/-- Return `some ⟨n, v⟩` if `e` is af `OfNat.ofNat` application encoding a `Fin n` with value `v` -/
/-- Return `some ⟨n, v⟩` if `e` is an `OfNat.ofNat` application encoding a `Fin n` with value `v` -/
def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run do
let (v, type) ← getOfNatValue? e ``Fin
let n ← getNatValue? (← whnfD type.appArg!)
Expand Down
89 changes: 89 additions & 0 deletions src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Fin.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ def fromExpr? (e : Expr) : SimpM (Option Value) := do
let some ⟨n, value⟩ ← getFinValue? e | return none
return some { n, value }

@[inline] def reduceOp (declName : Name) (arity : Nat) (f : Nat → Nat) (op : {n : Nat} → Fin n → Fin (f n)) (e : Expr) : SimpM DStep := do
unless e.isAppOfArity declName arity do return .continue
let some v ← fromExpr? e.appArg! | return .continue
let v' := op v.value
return .done <| toExpr v'

@[inline] def reduceNatOp (declName : Name) (arity : Nat) (f : Nat → Nat) (op : (n : Nat) → Fin (f n)) (e : Expr) : SimpM DStep := do
unless e.isAppOfArity declName arity do return .continue
let some v ← getNatValue? e.appArg! | return .continue
let v' := op v
return .done <| toExpr v'

@[inline] def reduceBin (declName : Name) (arity : Nat) (op : {n : Nat} → Fin n → Fin n → Fin n) (e : Expr) : SimpM DStep := do
unless e.isAppOfArity declName arity do return .continue
let some v₁ ← fromExpr? e.appFn!.appArg! | return .continue
Expand Down Expand Up @@ -47,12 +59,23 @@ The following code assumes users did not override the `Fin n` instances for the
If they do, they must disable the following `simprocs`.
-/

builtin_dsimproc [simp, seval] reduceSucc (Fin.succ _) := reduceOp ``Fin.succ 2 (· + 1) Fin.succ
builtin_dsimproc [simp, seval] reduceRev (Fin.rev _) := reduceOp ``Fin.rev 2 (·) Fin.rev
builtin_dsimproc [simp, seval] reduceLast (Fin.last _) := reduceNatOp ``Fin.last 1 (· + 1) Fin.last

builtin_dsimproc [simp, seval] reduceAdd ((_ + _ : Fin _)) := reduceBin ``HAdd.hAdd 6 (· + ·)
builtin_dsimproc [simp, seval] reduceMul ((_ * _ : Fin _)) := reduceBin ``HMul.hMul 6 (· * ·)
builtin_dsimproc [simp, seval] reduceSub ((_ - _ : Fin _)) := reduceBin ``HSub.hSub 6 (· - ·)
builtin_dsimproc [simp, seval] reduceDiv ((_ / _ : Fin _)) := reduceBin ``HDiv.hDiv 6 (· / ·)
builtin_dsimproc [simp, seval] reduceMod ((_ % _ : Fin _)) := reduceBin ``HMod.hMod 6 (· % ·)

builtin_dsimproc [simp, seval] reduceAnd ((_ &&& _ : Fin _)) := reduceBin ``HAnd.hAnd 6 (· &&& ·)
builtin_dsimproc [simp, seval] reduceOr ((_ ||| _ : Fin _)) := reduceBin ``HOr.hOr 6 (· ||| ·)
builtin_dsimproc [simp, seval] reduceXor ((_ ^^^ _ : Fin _)) := reduceBin ``HXor.hXor 6 (· ^^^ ·)

builtin_dsimproc [simp, seval] reduceShiftLeft ((_ <<< _ : Fin _)) := reduceBin ``HShiftLeft.hShiftLeft 6 (· <<< ·)
builtin_dsimproc [simp, seval] reduceShiftRight ((_ >>> _ : Fin _)) := reduceBin ``HShiftRight.hShiftRight 6 (· >>> ·)

builtin_simproc [simp, seval] reduceLT (( _ : Fin _) < _) := reduceBinPred ``LT.lt 4 (. < .)
builtin_simproc [simp, seval] reduceLE (( _ : Fin _) ≤ _) := reduceBinPred ``LE.le 4 (. ≤ .)
builtin_simproc [simp, seval] reduceGT (( _ : Fin _) > _) := reduceBinPred ``GT.gt 4 (. > .)
Expand Down Expand Up @@ -83,4 +106,70 @@ builtin_dsimproc [simp, seval] reduceFinMk (Fin.mk _ _) := fun e => do
else
return .continue

builtin_dsimproc [simp, seval] reduceOfNat' (Fin.ofNat' _ _) := fun e => do
unless e.isAppOfArity ``Fin.ofNat' 3 do return .continue
let some (n + 1) ← getNatValue? e.appFn!.appFn!.appArg! | return .continue
let some k ← getNatValue? e.appArg! | return .continue
return .done <| toExpr (Fin.ofNat' (n + 1) k)

builtin_dsimproc [simp, seval] reduceCastSucc (Fin.castSucc _) := fun e => do
unless e.isAppOfArity ``Fin.castSucc 2 do return .continue
let some k ← fromExpr? e.appArg! | return .continue
return .done <| toExpr (castSucc k.value)

builtin_dsimproc [simp, seval] reduceCastAdd (Fin.castAdd _ _) := fun e => do
unless e.isAppOfArity ``Fin.castAdd 3 do return .continue
let some m ← getNatValue? e.appFn!.appArg! | return .continue
let some k ← fromExpr? e.appArg! | return .continue
return .done <| toExpr (castAdd m k.value)

builtin_dsimproc [simp, seval] reduceAddNat (Fin.addNat _ _) := fun e => do
unless e.isAppOfArity ``Fin.addNat 3 do return .continue
let some k ← fromExpr? e.appFn!.appArg! | return .continue
let some m ← getNatValue? e.appArg! | return .continue
return .done <| toExpr (addNat k.value m)

builtin_dsimproc [simp, seval] reduceNatAdd (Fin.natAdd _ _) := fun e => do
unless e.isAppOfArity ``Fin.natAdd 3 do return .continue
let some m ← getNatValue? e.appFn!.appArg! | return .continue
let some k ← fromExpr? e.appArg! | return .continue
return .done <| toExpr (natAdd m k.value)

builtin_dsimproc [simp, seval] reduceCastLT (Fin.castLT _ _) := fun e => do
unless e.isAppOfArity ``Fin.castLT 4 do return .continue
let some n ← getNatValue? e.appFn!.appFn!.appFn!.appArg! | return .continue
let some i ← fromExpr? e.appFn!.appArg! | return .continue
if h : i.value < n then
return .done <| toExpr (castLT i.value h)
else
return .continue

builtin_dsimproc [simp, seval] reduceCastLE (Fin.castLE _ _) := fun e => do
unless e.isAppOfArity ``Fin.castLE 4 do return .continue
let some m ← getNatValue? e.appFn!.appFn!.appArg! | return .continue
let some i ← fromExpr? e.appArg! | return .continue
if h : i.n ≤ m then
return .done <| toExpr (castLE h i.value)
else
return .continue

-- No simproc is needed for `Fin.cast`, as for explicit numbers `Fin.cast_refl` will apply.

builtin_dsimproc [simp, seval] reduceSubNat (Fin.subNat _ _ _) := fun e => do
unless e.isAppOfArity ``Fin.subNat 4 do return .continue
let some m ← getNatValue? e.appFn!.appFn!.appArg! | return .continue
let some i ← fromExpr? e.appFn!.appArg! | return .continue
if h : m ≤ i.value then
return .done <| toExpr (subNat m (i.value.cast (by omega : i.n = (i.n - m) + m)) h)
else
return .continue

builtin_dsimproc [simp, seval] reducePred (Fin.pred _ _) := fun e => do
unless e.isAppOfArity ``Fin.pred 3 do return .continue
let some ⟨(_ + 1), i⟩ ← fromExpr? e.appFn!.appArg! | return .continue
if h : i ≠ 0 then
return .done <| toExpr (pred i h)
else
return .continue

end Fin
75 changes: 75 additions & 0 deletions tests/lean/run/simprocFin.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
variable (n : Nat) [NeZero n]

/- basic operations -/

#check_simp (3 : Fin 7).succ ~> (4 : Fin 8)
#check_simp (6 : Fin 7).succ ~> (7 : Fin 8)
#check_simp Fin.last 0 ~> (0 : Fin 1)
#check_simp Fin.last 6 ~> (6 : Fin 7)
#check_simp Fin.ofNat' 6 3 ~> (3 : Fin 6)
#check_simp Fin.ofNat' 6 37 ~> (1 : Fin 6)
#check_simp Fin.rev (0 : Fin 7) ~> (6 : Fin 7)
#check_simp Fin.rev (3 : Fin 7) ~> (3 : Fin 7)
#check_simp Fin.castSucc (0 : Fin 7) ~> (0 : Fin 8)
#check_simp Fin.castSucc (3 : Fin 7) ~> (3 : Fin 8)
#check_simp Fin.castAdd 3 (0 : Fin 7) ~> (0 : Fin 10)
#check_simp Fin.castAdd 3 (3 : Fin 7) ~> (3 : Fin 10)
#check_simp Fin.castLT (3 : Fin 10) (by decide : 3 < 5) ~> (3 : Fin 5)
#check_simp Fin.castLE (by decide : 5 ≤ 37) (3 : Fin 5) ~> (3 : Fin 37)
#check_simp Fin.pred (3 : Fin 7) (by decide) ~> (2 : Fin 6)

/- arithmetic operation tests -/

#check_simp (3 : Fin 7) + (1 : Fin 7) ~> 4
#check_simp (3 : Fin 7) + (5 : Fin 7) ~> 1
#check_simp (3 : Fin 7) * (1 : Fin 7) ~> 3
#check_simp (3 : Fin 7) * (3 : Fin 7) ~> 2
#check_simp (3 : Fin 7) - (1 : Fin 7) ~> 2
#check_simp (3 : Fin 7) - (5 : Fin 7) ~> 5
#check_simp (3 : Fin 7) / (1 : Fin 7) ~> 3
#check_simp (3 : Fin 7) / (5 : Fin 7) ~> 0
#check_simp (3 : Fin 7) % (0 : Fin 7) ~> 3
#check_simp (3 : Fin 7) % (1 : Fin 7) ~> 0
#check_simp (3 : Fin 7) % (5 : Fin 7) ~> 3

#check_simp (3 : Fin n) + (5 : Fin n) !~>
#check_simp (3 : Fin n) * (5 : Fin n) !~>
#check_simp (3 : Fin n) - (5 : Fin n) !~>
#check_simp (3 : Fin n) / (5 : Fin n) !~>
#check_simp (3 : Fin n) % (5 : Fin n) !~>

#check_simp Fin.addNat (3 : Fin 7) 3 ~> (6 : Fin 10)
#check_simp Fin.natAdd 3 (3 : Fin 7) ~> (6 : Fin 10)
#check_simp Fin.subNat 2 (3 : Fin 7) (by decide) ~> (1 : Fin 5)

/- bitwise operations -/

#check_simp (3 : Fin 7) &&& (1 : Fin 7) ~> 1
#check_simp (3 : Fin 7) ||| (1 : Fin 7) ~> 3
#check_simp (3 : Fin 7) ^^^ (1 : Fin 7) ~> 2
#check_simp (3 : Fin 7) <<< (1 : Fin 7) ~> 6
#check_simp (3 : Fin 7) >>> (1 : Fin 7) ~> 1

/- predicate tests -/

#check_simp (3 : Fin 7) < (1 : Fin 7) ~> False
#check_simp (3 : Fin 7) < (5 : Fin 7) ~> True
#check_simp (3 : Fin 7) ≤ (1 : Fin 7) ~> False
#check_simp (3 : Fin 7) ≤ (5 : Fin 7) ~> True
#check_simp (3 : Fin 7) > (1 : Fin 7) ~> True
#check_simp (3 : Fin 7) > (5 : Fin 7) ~> False
#check_simp (3 : Fin 7) ≥ (1 : Fin 7) ~> True
#check_simp (3 : Fin 7) ≥ (5 : Fin 7) ~> False
#check_simp (3 : Fin 7) = (1 : Fin 7) ~> False
#check_simp (3 : Fin 7) = (5 : Fin 7) ~> False
#check_simp (3 : Fin 7) = (3 : Fin 7) ~> True
#check_simp (3 : Fin 7) ≠ (1 : Fin 7) ~> True
#check_simp (3 : Fin 7) ≠ (3 : Fin 7) ~> False
#check_simp (3 : Fin 7) ≠ (5 : Fin 7) ~> True

#check_simp (3 : Fin 7) == (1 : Fin 7) ~> false
#check_simp (3 : Fin 7) == (3 : Fin 7) ~> true
#check_simp (3 : Fin 7) == (5 : Fin 7) ~> false
#check_simp (3 : Fin 7) != (1 : Fin 7) ~> true
#check_simp (3 : Fin 7) != (3 : Fin 7) ~> false
#check_simp (3 : Fin 7) != (5 : Fin 7) ~> true
Loading