Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: ArgsPacker.unpack to return Option #6359

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Lean/Elab/PreDefinition/WF/Fix.lean
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def groupGoalsByFunction (argsPacker : ArgsPacker) (numFuncs : Nat) (goals : Arr
let type ← goal.getType
let (.mdata _ (.app _ param)) := type
| throwError "MVar does not look like a recursive call:{indentExpr type}"
let (funidx, _) ← argsPacker.unpack param
let some (funidx, _) := argsPacker.unpack param
| throwError "Cannot unpack param, unexpected expression:{indentExpr param}"
r := r.modify funidx (·.push goal)
return r

Expand Down
8 changes: 6 additions & 2 deletions src/Lean/Elab/PreDefinition/WF/GuessLex.lean
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,10 @@ def collectRecCalls (unaryPreDef : PreDefinition) (fixedPrefixSize : Nat)
throwError "Insufficient arguments in recursive call"
let arg := args[fixedPrefixSize]!
trace[Elab.definition.wf] "collectRecCalls: {unaryPreDef.declName} ({param}) → {unaryPreDef.declName} ({arg})"
let (caller, params) ← argsPacker.unpack param
let (callee, args) ← argsPacker.unpack arg
let some (caller, params) := argsPacker.unpack param
| throwError "Cannot unpack param, unexpected expression:{indentExpr param}"
let some (callee, args) := argsPacker.unpack arg
| throwError "Cannot unpack arg, unexpected expression:{indentExpr arg}"
RecCallWithContext.create (← getRef) caller (ys ++ params) callee (ys ++ args)

/-- Is the expression a `<`-like comparison of `Nat` expressions -/
Expand Down Expand Up @@ -771,6 +773,8 @@ Main entry point of this module:
Try to find a lexicographic ordering of the arguments for which the recursive definition
terminates. See the module doc string for a high-level overview.
The `preDefs` are used to determine arity and types of arguments; the bodies are ignored.
-/
def guessLex (preDefs : Array PreDefinition) (unaryPreDef : PreDefinition)
(fixedPrefixSize : Nat) (argsPacker : ArgsPacker) :
Expand Down
33 changes: 21 additions & 12 deletions src/Lean/Meta/ArgsPacker.lean
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,18 @@ Unpacks a unary packed argument created with `Unary.pack`.
Throws an error if the expression is not of that form.
-/
def unpack (arity : Nat) (e : Expr) : MetaM (Array Expr) := do
def unpack (arity : Nat) (e : Expr) : Option (Array Expr) := do
let mut e := e
let mut args := #[]
while args.size + 1 < arity do
if e.isAppOfArity ``PSigma.mk 4 then
args := args.push (e.getArg! 2)
e := e.getArg! 3
else
throwError "Unexpected expression while unpacking n-ary argument"
none
args := args.push e
return args


/--
Given a (dependent) tuple `t` (using `PSigma`) of the given arity.
Return an array containing its "elements".
Expand Down Expand Up @@ -258,7 +257,7 @@ argument and function index.
Throws an error if the expression is not of that form.
-/
def unpack (numFuncs : Nat) (expr : Expr) : MetaM (Nat × Expr) := do
def unpack (numFuncs : Nat) (expr : Expr) : Option (Nat × Expr) := do
let mut funidx := 0
let mut e := expr
while funidx + 1 < numFuncs do
Expand All @@ -269,7 +268,7 @@ def unpack (numFuncs : Nat) (expr : Expr) : MetaM (Nat × Expr) := do
e := e.getArg! 2
break
else
throwError "Unexpected expression while unpacking mutual argument:{indentExpr expr}"
none
return (funidx, e)


Expand Down Expand Up @@ -377,14 +376,17 @@ and `(z : C) → R₂[z]`, returns an expression of type
(x : A ⊕' C) → (match x with | .inl x => R₁[x] | .inr R₂[z])
```
-/
def uncurry (es : Array Expr) : MetaM Expr := do
let types ← es.mapM inferType
let resultType ← uncurryType types
def uncurryWithType (resultType : Expr) (es : Array Expr) : MetaM Expr := do
forallBoundedTelescope resultType (some 1) fun xs codomain => do
let #[x] := xs | unreachable!
let value ← casesOn x codomain es.toList
mkLambdaFVars #[x] value

def uncurry (es : Array Expr) : MetaM Expr := do
let types ← es.mapM inferType
let resultType ← uncurryType types
uncurryWithType resultType es

/--
Given unary expressions `e₁`, `e₂` with types `(x : A) → R`
and `(z : C) → R`, returns an expression of type
Expand Down Expand Up @@ -414,14 +416,18 @@ def curryType (n : Nat) (type : Expr) : MetaM (Array Expr) := do

end Mutual

-- Now for the main definitions in this moduleo
-- Now for the main definitions in this module

/-- The number of functions being packed -/
def numFuncs (argsPacker : ArgsPacker) : Nat := argsPacker.varNamess.size

/-- The arities of the functions being packed -/
def arities (argsPacker : ArgsPacker) : Array Nat := argsPacker.varNamess.map (·.size)

def onlyOneUnary (argsPacker : ArgsPacker) :=
argsPacker.varNamess.size = 1 &&
argsPacker.varNamess[0]!.size = 1

def pack (argsPacker : ArgsPacker) (domain : Expr) (fidx : Nat) (args : Array Expr)
: MetaM Expr := do
assert! fidx < argsPacker.numFuncs
Expand All @@ -436,14 +442,13 @@ return the function index that is called and the arguments individually.
We expect precisely the expressions produced by `pack`, with manifest
`PSum.inr`, `PSum.inl` and `PSigma.mk` constructors, and thus take them apart
rather than using projectinos.
rather than using projections.
-/
def unpack (argsPacker : ArgsPacker) (e : Expr) : MetaM (Nat × Array Expr) := do
def unpack (argsPacker : ArgsPacker) (e : Expr) : Option (Nat × Array Expr) := do
let (funidx, e) ← Mutual.unpack argsPacker.numFuncs e
let args ← Unary.unpack argsPacker.varNamess[funidx]!.size e
return (funidx, args)


/--
Given types `(x : A) → (y : B[x]) → R₁[x,y]` and `(z : C) → R₂[z]`, returns the type uncurried type
```
Expand All @@ -465,6 +470,10 @@ def uncurry (argsPacker : ArgsPacker) (es : Array Expr) : MetaM Expr := do
let unary ← (Array.zipWith argsPacker.varNamess es Unary.uncurry).mapM id
Mutual.uncurry unary

def uncurryWithType (argsPacker : ArgsPacker) (resultType : Expr) (es : Array Expr) : MetaM Expr := do
let unary ← (Array.zipWith argsPacker.varNamess es Unary.uncurry).mapM id
Mutual.uncurryWithType resultType unary

/--
Given expressions `e₁`, `e₂` with types `(x : A) → (y : B[x]) → R`
and `(z : C) → R`, returns an expression of type
Expand Down
3 changes: 2 additions & 1 deletion src/Lean/Meta/Tactic/FunInd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,8 @@ def cleanPackedArgs (eqnInfo : WF.EqnInfo) (value : Expr) : MetaM Expr := do
let args := e.getAppArgs
if eqnInfo.fixedPrefixSize + 1 ≤ args.size then
let packedArg := args.back!
let (i, unpackedArgs) ← eqnInfo.argsPacker.unpack packedArg
let some (i, unpackedArgs) := eqnInfo.argsPacker.unpack packedArg
| throwError "Unexpected packedArg:{indentExpr packedArg}"
let e' := .const eqnInfo.declNames[i]! e.getAppFn.constLevels!
let e' := mkAppN e' args.pop
let e' := mkAppN e' unpackedArgs
Expand Down
Loading