Skip to content

Commit

Permalink
refactor: ArgsPacker.unpack to return Option (#6359)
Browse files Browse the repository at this point in the history
so that it can be used in pure code and that the error message can be
adjusted
  • Loading branch information
nomeata authored Dec 10, 2024
1 parent 9386511 commit d27c5af
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 16 deletions.
3 changes: 2 additions & 1 deletion src/Lean/Elab/PreDefinition/WF/Fix.lean
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def groupGoalsByFunction (argsPacker : ArgsPacker) (numFuncs : Nat) (goals : Arr
let type ← goal.getType
let (.mdata _ (.app _ param)) := type
| throwError "MVar does not look like a recursive call:{indentExpr type}"
let (funidx, _) ← argsPacker.unpack param
let some (funidx, _) := argsPacker.unpack param
| throwError "Cannot unpack param, unexpected expression:{indentExpr param}"
r := r.modify funidx (·.push goal)
return r

Expand Down
8 changes: 6 additions & 2 deletions src/Lean/Elab/PreDefinition/WF/GuessLex.lean
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,10 @@ def collectRecCalls (unaryPreDef : PreDefinition) (fixedPrefixSize : Nat)
throwError "Insufficient arguments in recursive call"
let arg := args[fixedPrefixSize]!
trace[Elab.definition.wf] "collectRecCalls: {unaryPreDef.declName} ({param}) → {unaryPreDef.declName} ({arg})"
let (caller, params) ← argsPacker.unpack param
let (callee, args) ← argsPacker.unpack arg
let some (caller, params) := argsPacker.unpack param
| throwError "Cannot unpack param, unexpected expression:{indentExpr param}"
let some (callee, args) := argsPacker.unpack arg
| throwError "Cannot unpack arg, unexpected expression:{indentExpr arg}"
RecCallWithContext.create (← getRef) caller (ys ++ params) callee (ys ++ args)

/-- Is the expression a `<`-like comparison of `Nat` expressions -/
Expand Down Expand Up @@ -771,6 +773,8 @@ Main entry point of this module:
Try to find a lexicographic ordering of the arguments for which the recursive definition
terminates. See the module doc string for a high-level overview.
The `preDefs` are used to determine arity and types of arguments; the bodies are ignored.
-/
def guessLex (preDefs : Array PreDefinition) (unaryPreDef : PreDefinition)
(fixedPrefixSize : Nat) (argsPacker : ArgsPacker) :
Expand Down
33 changes: 21 additions & 12 deletions src/Lean/Meta/ArgsPacker.lean
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,18 @@ Unpacks a unary packed argument created with `Unary.pack`.
Throws an error if the expression is not of that form.
-/
def unpack (arity : Nat) (e : Expr) : MetaM (Array Expr) := do
def unpack (arity : Nat) (e : Expr) : Option (Array Expr) := do
let mut e := e
let mut args := #[]
while args.size + 1 < arity do
if e.isAppOfArity ``PSigma.mk 4 then
args := args.push (e.getArg! 2)
e := e.getArg! 3
else
throwError "Unexpected expression while unpacking n-ary argument"
none
args := args.push e
return args


/--
Given a (dependent) tuple `t` (using `PSigma`) of the given arity.
Return an array containing its "elements".
Expand Down Expand Up @@ -258,7 +257,7 @@ argument and function index.
Throws an error if the expression is not of that form.
-/
def unpack (numFuncs : Nat) (expr : Expr) : MetaM (Nat × Expr) := do
def unpack (numFuncs : Nat) (expr : Expr) : Option (Nat × Expr) := do
let mut funidx := 0
let mut e := expr
while funidx + 1 < numFuncs do
Expand All @@ -269,7 +268,7 @@ def unpack (numFuncs : Nat) (expr : Expr) : MetaM (Nat × Expr) := do
e := e.getArg! 2
break
else
throwError "Unexpected expression while unpacking mutual argument:{indentExpr expr}"
none
return (funidx, e)


Expand Down Expand Up @@ -377,14 +376,17 @@ and `(z : C) → R₂[z]`, returns an expression of type
(x : A ⊕' C) → (match x with | .inl x => R₁[x] | .inr R₂[z])
```
-/
def uncurry (es : Array Expr) : MetaM Expr := do
let types ← es.mapM inferType
let resultType ← uncurryType types
def uncurryWithType (resultType : Expr) (es : Array Expr) : MetaM Expr := do
forallBoundedTelescope resultType (some 1) fun xs codomain => do
let #[x] := xs | unreachable!
let value ← casesOn x codomain es.toList
mkLambdaFVars #[x] value

def uncurry (es : Array Expr) : MetaM Expr := do
let types ← es.mapM inferType
let resultType ← uncurryType types
uncurryWithType resultType es

/--
Given unary expressions `e₁`, `e₂` with types `(x : A) → R`
and `(z : C) → R`, returns an expression of type
Expand Down Expand Up @@ -414,14 +416,18 @@ def curryType (n : Nat) (type : Expr) : MetaM (Array Expr) := do

end Mutual

-- Now for the main definitions in this moduleo
-- Now for the main definitions in this module

/-- The number of functions being packed -/
def numFuncs (argsPacker : ArgsPacker) : Nat := argsPacker.varNamess.size

/-- The arities of the functions being packed -/
def arities (argsPacker : ArgsPacker) : Array Nat := argsPacker.varNamess.map (·.size)

def onlyOneUnary (argsPacker : ArgsPacker) :=
argsPacker.varNamess.size = 1 &&
argsPacker.varNamess[0]!.size = 1

def pack (argsPacker : ArgsPacker) (domain : Expr) (fidx : Nat) (args : Array Expr)
: MetaM Expr := do
assert! fidx < argsPacker.numFuncs
Expand All @@ -436,14 +442,13 @@ return the function index that is called and the arguments individually.
We expect precisely the expressions produced by `pack`, with manifest
`PSum.inr`, `PSum.inl` and `PSigma.mk` constructors, and thus take them apart
rather than using projectinos.
rather than using projections.
-/
def unpack (argsPacker : ArgsPacker) (e : Expr) : MetaM (Nat × Array Expr) := do
def unpack (argsPacker : ArgsPacker) (e : Expr) : Option (Nat × Array Expr) := do
let (funidx, e) ← Mutual.unpack argsPacker.numFuncs e
let args ← Unary.unpack argsPacker.varNamess[funidx]!.size e
return (funidx, args)


/--
Given types `(x : A) → (y : B[x]) → R₁[x,y]` and `(z : C) → R₂[z]`, returns the type uncurried type
```
Expand All @@ -465,6 +470,10 @@ def uncurry (argsPacker : ArgsPacker) (es : Array Expr) : MetaM Expr := do
let unary ← (Array.zipWith argsPacker.varNamess es Unary.uncurry).mapM id
Mutual.uncurry unary

def uncurryWithType (argsPacker : ArgsPacker) (resultType : Expr) (es : Array Expr) : MetaM Expr := do
let unary ← (Array.zipWith argsPacker.varNamess es Unary.uncurry).mapM id
Mutual.uncurryWithType resultType unary

/--
Given expressions `e₁`, `e₂` with types `(x : A) → (y : B[x]) → R`
and `(z : C) → R`, returns an expression of type
Expand Down
3 changes: 2 additions & 1 deletion src/Lean/Meta/Tactic/FunInd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,8 @@ def cleanPackedArgs (eqnInfo : WF.EqnInfo) (value : Expr) : MetaM Expr := do
let args := e.getAppArgs
if eqnInfo.fixedPrefixSize + 1 ≤ args.size then
let packedArg := args.back!
let (i, unpackedArgs) ← eqnInfo.argsPacker.unpack packedArg
let some (i, unpackedArgs) := eqnInfo.argsPacker.unpack packedArg
| throwError "Unexpected packedArg:{indentExpr packedArg}"
let e' := .const eqnInfo.declNames[i]! e.getAppFn.constLevels!
let e' := mkAppN e' args.pop
let e' := mkAppN e' unpackedArgs
Expand Down

0 comments on commit d27c5af

Please sign in to comment.