Skip to content

Commit

Permalink
feat: _ separators in numeric literals (#6204)
Browse files Browse the repository at this point in the history
This PR allows `_` in numeric literals as a separator. For example,
`1_000_000`, `0xff_ff` or `0b_10_11_01_00`. New lexical syntax:
```text
numeral10 : [0-9]+ ("_"+ [0-9]+)*
numeral2  : "0" [bB] ("_"* [0-1]+)+
numeral8  : "0" [oO] ("_"* [0-7]+)+
numeral16 : "0" [xX] ("_"* hex_char+)+
float     : numeral10 "." numeral10? ([eE] [+-]? numeral10)?
```

Closes #6199
  • Loading branch information
kmill authored Dec 8, 2024
1 parent 6abb8aa commit 4cd5032
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 17 deletions.
11 changes: 6 additions & 5 deletions doc/lexical_structure.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,16 @@ Numeric literals can be specified in various bases.

```
numeral : numeral10 | numeral2 | numeral8 | numeral16
numeral10 : [0-9]+
numeral2 : "0" [bB] [0-1]+
numeral8 : "0" [oO] [0-7]+
numeral16 : "0" [xX] hex_char+
numeral10 : [0-9]+ ("_"+ [0-9]+)*
numeral2 : "0" [bB] ("_"* [0-1]+)+
numeral8 : "0" [oO] ("_"* [0-7]+)+
numeral16 : "0" [xX] ("_"* hex_char+)+
```

Floating point literals are also possible with optional exponent:

```
float : [0-9]+ "." [0-9]+ [[eE[+-][0-9]+]
float : numeral10 "." numeral10? ([eE] [+-]? numeral10)?
```

For example:
Expand All @@ -147,6 +147,7 @@ constant w : Int := 55
constant x : Nat := 26085
constant y : Nat := 0x65E5
constant z : Float := 2.548123e-05
constant b : Nat := 0b_11_01_10_00
```

Note that negative numbers are created by applying the "-" negation prefix operator to the number, for example:
Expand Down
13 changes: 12 additions & 1 deletion src/Init/Meta.lean
Original file line number Diff line number Diff line change
Expand Up @@ -679,13 +679,15 @@ private partial def decodeBinLitAux (s : String) (i : String.Pos) (val : Nat) :
let c := s.get i
if c == '0' then decodeBinLitAux s (s.next i) (2*val)
else if c == '1' then decodeBinLitAux s (s.next i) (2*val + 1)
else if c == '_' then decodeBinLitAux s (s.next i) val
else none

/--
Auxiliary decoder for the digits of an octal literal.
Starting at position `i` of `s`, folds each digit `0`–`7` into the accumulator `val` (base 8),
skips `_` separators without changing `val`, and returns `some val` once the end of `s` is reached.
Any other character yields `none`.
-/
private partial def decodeOctalLitAux (s : String) (i : String.Pos) (val : Nat) : Option Nat :=
  if s.atEnd i then
    some val
  else
    let ch := s.get i
    let rest := s.next i
    if ch == '_' then
      -- Separators carry no value; just move past them.
      decodeOctalLitAux s rest val
    else if '0' ≤ ch && ch ≤ '7' then
      decodeOctalLitAux s rest (8*val + ch.toNat - '0'.toNat)
    else
      none

private def decodeHexDigit (s : String) (i : String.Pos) : Option (Nat × String.Pos) :=
Expand All @@ -700,13 +702,16 @@ private partial def decodeHexLitAux (s : String) (i : String.Pos) (val : Nat) :
if s.atEnd i then some val
else match decodeHexDigit s i with
| some (d, i) => decodeHexLitAux s i (16*val + d)
| none => none
| none =>
if s.get i == '_' then decodeHexLitAux s (s.next i) val
else none

/--
Auxiliary decoder for the digits of a decimal literal.
Starting at position `i` of `s`, folds each digit `0`–`9` into the accumulator `val` (base 10),
skips `_` separators without changing `val`, and returns `some val` once the end of `s` is reached.
Any other character yields `none`.
-/
private partial def decodeDecimalLitAux (s : String) (i : String.Pos) (val : Nat) : Option Nat :=
  if s.atEnd i then
    some val
  else
    let ch := s.get i
    let rest := s.next i
    if ch == '_' then
      -- Separators carry no value; just move past them.
      decodeDecimalLitAux s rest val
    else if '0' ≤ ch && ch ≤ '9' then
      decodeDecimalLitAux s rest (10*val + ch.toNat - '0'.toNat)
    else
      none

def decodeNatLitVal? (s : String) : Option Nat :=
Expand Down Expand Up @@ -773,6 +778,8 @@ where
let c := s.get i
if '0' ≤ c && c ≤ '9' then
decodeAfterExp (s.next i) val e sign (10*exp + c.toNat - '0'.toNat)
else if c == '_' then
decodeAfterExp (s.next i) val e sign exp
else
none

Expand All @@ -793,6 +800,8 @@ where
let c := s.get i
if '0' ≤ c && c ≤ '9' then
decodeAfterDot (s.next i) (10*val + c.toNat - '0'.toNat) (e+1)
else if c == '_' then
decodeAfterDot (s.next i) val e
else if c == 'e' || c == 'E' then
decodeExp (s.next i) val e
else
Expand All @@ -805,6 +814,8 @@ where
let c := s.get i
if '0' ≤ c && c ≤ '9' then
decode (s.next i) (10*val + c.toNat - '0'.toNat)
else if c == '_' then
decode (s.next i) val
else if c == '.' then
decodeAfterDot (s.next i) val 0
else if c == 'e' || c == 'E' then
Expand Down
49 changes: 38 additions & 11 deletions src/Lean/Parser/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -804,18 +804,45 @@ where
else
normalState num c s

/--
Consumes digits interleaved with `_` separators, i.e. a sequence of the form `many (many '_' >> many1 digit)`.
`isDigit` decides which characters count as digits, and `expecting` is the description reported as expected on failure.
When `needDigit` is `true` a digit must still follow: reaching the end of input or a non-digit character is an error.
Every `_` raises this requirement again, so a separator can never be the last character consumed.

Note: `_` is never reported as an expected token when the end of input or an unexpected character is reached.
Rationale: such an error only arises right after a `_`, and although runs of `_` are allowed,
suggesting that the run be extended would be unhelpful.
-/
partial def takeDigitsFn (isDigit : Char → Bool) (expecting : String) (needDigit : Bool) : ParserFn := fun c s =>
  let pos := s.pos
  if h : c.input.atEnd pos then
    if needDigit then s.mkEOIError [expecting] else s
  else
    let ch := c.input.get' pos h
    if ch == '_' || isDigit ch then
      -- A separator demands a following digit; a digit satisfies any pending demand.
      takeDigitsFn isDigit expecting (ch == '_') c (s.next' c.input pos h)
    else if needDigit then
      s.mkUnexpectedError "unexpected character" (expected := [expecting])
    else
      s

def decimalNumberFn (startPos : String.Pos) (c : ParserContext) : ParserState → ParserState := fun s =>
let s := takeWhileFn (fun c => c.isDigit) c s
let s := takeDigitsFn (fun c => c.isDigit) "decimal number" false c s
let input := c.input
let i := s.pos
let curr := input.get i
if curr == '.' || curr == 'e' || curr == 'E' then
if h : input.atEnd i then
mkNodeToken numLitKind startPos c s
else
let curr := input.get' i h
if curr == '.' || curr == 'e' || curr == 'E' then
parseScientific s
else
mkNodeToken numLitKind startPos c s
where
parseScientific s :=
let s := parseOptDot s
let s := parseOptExp s
mkNodeToken scientificLitKind startPos c s
else
mkNodeToken numLitKind startPos c s
where

parseOptDot s :=
let input := c.input
let i := s.pos
Expand All @@ -824,7 +851,7 @@ where
let i := input.next i
let curr := input.get i
if curr.isDigit then
takeWhileFn (fun c => c.isDigit) c (s.setPos i)
takeDigitsFn (fun c => c.isDigit) "decimal number" false c (s.setPos i)
else
s.setPos i
else
Expand All @@ -839,22 +866,22 @@ where
let i := if input.get i == '-' || input.get i == '+' then input.next i else i
let curr := input.get i
if curr.isDigit then
takeWhileFn (fun c => c.isDigit) c (s.setPos i)
takeDigitsFn (fun c => c.isDigit) "decimal number" false c (s.setPos i)
else
s.mkUnexpectedError "missing exponent digits in scientific literal"
else
s

def binNumberFn (startPos : String.Pos) : ParserFn := fun c s =>
let s := takeWhile1Fn (fun c => c == '0' || c == '1') "binary number" c s
let s := takeDigitsFn (fun c => c == '0' || c == '1') "binary number" true c s
mkNodeToken numLitKind startPos c s

def octalNumberFn (startPos : String.Pos) : ParserFn := fun c s =>
let s := takeWhile1Fn (fun c => '0' ≤ c && c ≤ '7') "octal number" c s
let s := takeDigitsFn (fun c => '0' ≤ c && c ≤ '7') "octal number" true c s
mkNodeToken numLitKind startPos c s

def hexNumberFn (startPos : String.Pos) : ParserFn := fun c s =>
let s := takeWhile1Fn (fun c => ('0' ≤ c && c ≤ '9') || ('a' ≤ c && c ≤ 'f') || ('A' ≤ c && c ≤ 'F')) "hexadecimal number" c s
let s := takeDigitsFn (fun c => ('0' ≤ c && c ≤ '9') || ('a' ≤ c && c ≤ 'f') || ('A' ≤ c && c ≤ 'F')) "hexadecimal number" true c s
mkNodeToken numLitKind startPos c s

def numberFnAux : ParserFn := fun c s =>
Expand Down
107 changes: 107 additions & 0 deletions tests/lean/run/6199.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import Lean
/-!
# `_` separators for numeric literals
RFC: https://github.com/leanprover/lean4/issues/6199
-/
set_option pp.mvars false

open Lean Elab

-- `#term "<src>"` is a test helper command: it parses the string literal as a `term`,
-- elaborates the result, and logs `<expr> : <type>` as an info message.
-- Parse errors from `runParserCategory` are rethrown via `Lean.ofExcept`, so malformed
-- literals can be tested from inside a string without breaking this file's own syntax.
elab "#term " s:str : command => Command.liftTermElabM <| withRef s do
-- Run the `term` parser on the raw string contents.
let t ← Lean.ofExcept <| Parser.runParserCategory (← getEnv) `term s.getString
-- Elaborate with no expected type.
let e ← Term.elabTermAndSynthesize t none
logInfo m!"{e} : {← Meta.inferType e}"

/-!
Decimal tests
-/
-- Plain literals without separators.
/-- info: 0 : Nat -/
#guard_msgs in #term "0"
/-- info: 1 : Nat -/
#guard_msgs in #term "1"
/-- info: 1000000 : Nat -/
#guard_msgs in #term "1000000"
-- Separators, including runs of several `_`, do not change the value.
/-- info: 1000000 : Nat -/
#guard_msgs in #term "1_000_000"
/-- info: 1000000 : Nat -/
#guard_msgs in #term "1__000___000"
-- A trailing `_` is rejected, both at end of input and before another character.
/-- error: <input>:1:5: unexpected end of input; expected decimal number -/
#guard_msgs in #term "1_00_"
/-- error: <input>:1:6: unexpected character; expected decimal number -/
#guard_msgs in #term "(1_00_)"
-- Starting with `_` is an identifier:
/--
error: unknown identifier '_10'
---
info: sorry : ?_
-/
#guard_msgs in #term "_10"

/-!
Scientific tests
-/
-- Separators in the integer part; the fractional part may be empty.
/-- info: 100000. : Float -/
#guard_msgs in #term "100_000."
/-- info: 100000.0 : Float -/
#guard_msgs in #term "100_000.0"
/-- info: 0. : Float -/
#guard_msgs in #term "0."
-- The decimal parser requires a digit at the start, so the `_` is left over:
/-- error: <input>:1:4: expected end of input -/
#guard_msgs in #term "100._"
-- Separators inside the fractional part are allowed, but not at its end.
/-- info: 100.111111 : Float -/
#guard_msgs in #term "100.111_111"
/-- error: <input>:1:8: unexpected end of input; expected decimal number -/
#guard_msgs in #term "100.111_"
/-- error: <input>:1:9: unexpected character; expected decimal number -/
#guard_msgs in #term "(100.111_)"
-- Separators in the exponent: `1_100` is read as 1100, and the six fractional digits
-- shift the displayed exponent to 1100 - 6 = 1094.
/-- info: 100111111e1094 : Float -/
#guard_msgs in #term "100.111_111e1_100"
-- A trailing `_` in the exponent is rejected.
/-- error: <input>:1:5: unexpected end of input; expected decimal number -/
#guard_msgs in #term "1e10_"
/-- error: <input>:1:6: unexpected character; expected decimal number -/
#guard_msgs in #term "(1e10_)"
-- The exponent must begin with a digit, so `_` directly after `e` is rejected.
/-- error: <input>:1:1: missing exponent digits in scientific literal -/
#guard_msgs in #term "1e_10"

/-!
Base-2 tests
-/
-- Separators may appear between digits and directly after the `0b` prefix.
/-- info: 15 : Nat -/
#guard_msgs in #term "0b1111"
/-- info: 15 : Nat -/
#guard_msgs in #term "0b11_11"
/-- info: 15 : Nat -/
#guard_msgs in #term "0b__11__11"
-- A trailing `_` is rejected.
/-- error: <input>:1:7: unexpected end of input; expected binary number -/
#guard_msgs in #term "0b1111_"
/-- error: <input>:1:8: unexpected character; expected binary number -/
#guard_msgs in #term "(0b1111_)"

/-!
Base-8 tests
-/
-- Separators may appear between digits and directly after the `0o` prefix.
/-- info: 512 : Nat -/
#guard_msgs in #term "0o1000"
/-- info: 512 : Nat -/
#guard_msgs in #term "0o1_000"
/-- info: 512 : Nat -/
#guard_msgs in #term "0o_1_000"
-- A trailing `_` is rejected.
/-- error: <input>:1:7: unexpected end of input; expected octal number -/
#guard_msgs in #term "0o1000_"
/-- error: <input>:1:8: unexpected character; expected octal number -/
#guard_msgs in #term "(0o1000_)"

/-!
Base-16 tests
-/
-- Separators may appear between digits and directly after the `0x` prefix.
/-- info: 4096 : Nat -/
#guard_msgs in #term "0x1000"
/-- info: 4096 : Nat -/
#guard_msgs in #term "0x1_000"
/-- info: 4096 : Nat -/
#guard_msgs in #term "0x_1_000"
-- A trailing `_` is rejected.
/-- error: <input>:1:7: unexpected end of input; expected hexadecimal number -/
#guard_msgs in #term "0x1000_"
/-- error: <input>:1:8: unexpected character; expected hexadecimal number -/
#guard_msgs in #term "(0x1000_)"

0 comments on commit 4cd5032

Please sign in to comment.