Skip to content

Commit

Permalink
feat: Float32
Browse files Browse the repository at this point in the history
This PR adds support for `Float32` to the Lean runtime.

We need an update stage0, and then remove `#exit` from new
`Float32.lean` file.
  • Loading branch information
leodemoura committed Dec 9, 2024
1 parent 520d4b6 commit ba71869
Show file tree
Hide file tree
Showing 16 changed files with 382 additions and 56 deletions.
1 change: 1 addition & 0 deletions src/Init/Data.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import Init.Data.Fin
import Init.Data.UInt
import Init.Data.SInt
import Init.Data.Float
import Init.Data.Float32
import Init.Data.Option
import Init.Data.Ord
import Init.Data.Random
Expand Down
180 changes: 180 additions & 0 deletions src/Init/Data/Float32.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/-
Copyright (c) 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Core
import Init.Data.Int.Basic
import Init.Data.ToString.Basic
import Init.Data.Float

/-
#exit -- TODO: Remove after update stage0
-- Just show FloatSpec is inhabited.
opaque float32Spec : FloatSpec := {
float := Unit,
val := (),
lt := fun _ _ => True,
le := fun _ _ => True,
decLt := fun _ _ => inferInstanceAs (Decidable True),
decLe := fun _ _ => inferInstanceAs (Decidable True)
}
/-- Native floating point type, corresponding to the IEEE 754 *binary32* format
(`float` in C or `f32` in Rust). -/
structure Float32 where
val : float32Spec.float
instance : Nonempty Float32 := ⟨{ val := float32Spec.val }⟩
@[extern "lean_float32_add"] opaque Float32.add : Float32 → Float32 → Float32
@[extern "lean_float32_sub"] opaque Float32.sub : Float32 → Float32 → Float32
@[extern "lean_float32_mul"] opaque Float32.mul : Float32 → Float32 → Float32
@[extern "lean_float32_div"] opaque Float32.div : Float32 → Float32 → Float32
@[extern "lean_float32_negate"] opaque Float32.neg : Float32 → Float32
set_option bootstrap.genMatcherCode false
def Float32.lt : Float32 → Float32 → Prop := fun a b =>
match a, b with
| ⟨a⟩, ⟨b⟩ => float32Spec.lt a b
def Float32.le : Float32 → Float32 → Prop := fun a b =>
float32Spec.le a.val b.val
/--
Raw transmutation from `UInt32`.
Float32s and UInts have the same endianness on all supported platforms.
IEEE 754 very precisely specifies the bit layout of floats.
-/
@[extern "lean_float32_of_bits"] opaque Float32.ofBits : UInt32 → Float32
/--
Raw transmutation to `UInt32`.
Float32s and UInts have the same endianness on all supported platforms.
IEEE 754 very precisely specifies the bit layout of floats.
Note that this function is distinct from `Float32.toUInt32`, which attempts
to preserve the numeric value, and not the bitwise value.
-/
@[extern "lean_float32_to_bits"] opaque Float32.toBits : Float32 → UInt32
instance : Add Float32 := ⟨Float32.add⟩
instance : Sub Float32 := ⟨Float32.sub⟩
instance : Mul Float32 := ⟨Float32.mul⟩
instance : Div Float32 := ⟨Float32.div⟩
instance : Neg Float32 := ⟨Float32.neg⟩
instance : LT Float32 := ⟨Float32.lt⟩
instance : LE Float32 := ⟨Float32.le⟩
/-- Note: this is not reflexive since `NaN != NaN`.-/
@[extern "lean_float32_beq"] opaque Float32.beq (a b : Float32) : Bool
instance : BEq Float32 := ⟨Float32.beq⟩
@[extern "lean_float32_decLt"] opaque Float32.decLt (a b : Float32) : Decidable (a < b) :=
match a, b with
| ⟨a⟩, ⟨b⟩ => float32Spec.decLt a b
@[extern "lean_float32_decLe"] opaque Float32.decLe (a b : Float32) : Decidable (a ≤ b) :=
match a, b with
| ⟨a⟩, ⟨b⟩ => float32Spec.decLe a b
instance float32DecLt (a b : Float32) : Decidable (a < b) := Float32.decLt a b
instance float32DecLe (a b : Float32) : Decidable (a ≤ b) := Float32.decLe a b
@[extern "lean_float32_to_string"] opaque Float32.toString : Float32 → String
/-- If the given float is non-negative, truncates the value to the nearest non-negative integer.
If negative or NaN, returns `0`.
If larger than the maximum value for `UInt8` (including Inf), returns the maximum value of `UInt8`
(i.e. `UInt8.size - 1`).
-/
@[extern "lean_float32_to_uint8"] opaque Float32.toUInt8 : Float32 → UInt8
/-- If the given float is non-negative, truncates the value to the nearest non-negative integer.
If negative or NaN, returns `0`.
If larger than the maximum value for `UInt16` (including Inf), returns the maximum value of `UInt16`
(i.e. `UInt16.size - 1`).
-/
@[extern "lean_float32_to_uint16"] opaque Float32.toUInt16 : Float32 → UInt16
/-- If the given float is non-negative, truncates the value to the nearest non-negative integer.
If negative or NaN, returns `0`.
If larger than the maximum value for `UInt32` (including Inf), returns the maximum value of `UInt32`
(i.e. `UInt32.size - 1`).
-/
@[extern "lean_float32_to_uint32"] opaque Float32.toUInt32 : Float32 → UInt32
/-- If the given float is non-negative, truncates the value to the nearest non-negative integer.
If negative or NaN, returns `0`.
If larger than the maximum value for `UInt64` (including Inf), returns the maximum value of `UInt64`
(i.e. `UInt64.size - 1`).
-/
@[extern "lean_float32_to_uint64"] opaque Float32.toUInt64 : Float32 → UInt64
/-- If the given float is non-negative, truncates the value to the nearest non-negative integer.
If negative or NaN, returns `0`.
If larger than the maximum value for `USize` (including Inf), returns the maximum value of `USize`
(i.e. `USize.size - 1`). This value is platform dependent).
-/
@[extern "lean_float32_to_usize"] opaque Float32.toUSize : Float32 → USize
@[extern "lean_float32_isnan"] opaque Float32.isNaN : Float32 → Bool
@[extern "lean_float32_isfinite"] opaque Float32.isFinite : Float32 → Bool
@[extern "lean_float32_isinf"] opaque Float32.isInf : Float32 → Bool
/-- Splits the given float `x` into a significand/exponent pair `(s, i)`
such that `x = s * 2^i` where `s ∈ (-1;-0.5] ∪ [0.5; 1)`.
Returns an undefined value if `x` is not finite.
-/
@[extern "lean_float32_frexp"] opaque Float32.frExp : Float32 → Float32 × Int
instance : ToString Float32 where
toString := Float32.toString
@[extern "lean_uint64_to_float"] opaque UInt64.toFloat32 (n : UInt64) : Float32
instance : Inhabited Float32 where
default := UInt64.toFloat32 0
instance : Repr Float32 where
reprPrec n prec := if n < UInt64.toFloat32 0 then Repr.addAppParen (toString n) prec else toString n
instance : ReprAtom Float32 := ⟨⟩
@[extern "sinf"] opaque Float32.sin : Float32 → Float32
@[extern "cosf"] opaque Float32.cos : Float32 → Float32
@[extern "tanf"] opaque Float32.tan : Float32 → Float32
@[extern "asinf"] opaque Float32.asin : Float32 → Float32
@[extern "acosf"] opaque Float32.acos : Float32 → Float32
@[extern "atanf"] opaque Float32.atan : Float32 → Float32
@[extern "atan2f"] opaque Float32.atan2 : Float32 → Float32 → Float32
@[extern "sinhf"] opaque Float32.sinh : Float32 → Float32
@[extern "coshf"] opaque Float32.cosh : Float32 → Float32
@[extern "tanhf"] opaque Float32.tanh : Float32 → Float32
@[extern "asinhf"] opaque Float32.asinh : Float32 → Float32
@[extern "acoshf"] opaque Float32.acosh : Float32 → Float32
@[extern "atanhf"] opaque Float32.atanh : Float32 → Float32
@[extern "expf"] opaque Float32.exp : Float32 → Float32
@[extern "exp2f"] opaque Float32.exp2 : Float32 → Float32
@[extern "logf"] opaque Float32.log : Float32 → Float32
@[extern "log2f"] opaque Float32.log2 : Float32 → Float32
@[extern "log10f"] opaque Float32.log10 : Float32 → Float32
@[extern "powf"] opaque Float32.pow : Float32 → Float32 → Float32
@[extern "sqrtf"] opaque Float32.sqrt : Float32 → Float32
@[extern "cbrtf"] opaque Float32.cbrt : Float32 → Float32
@[extern "ceilf"] opaque Float32.ceil : Float32 → Float32
@[extern "floorf"] opaque Float32.floor : Float32 → Float32
@[extern "roundf"] opaque Float32.round : Float32 → Float32
@[extern "fabsf"] opaque Float32.abs : Float32 → Float32
instance : HomogeneousPow Float32 := ⟨Float32.pow⟩
instance : Min Float32 := minOfLe
instance : Max Float32 := maxOfLe
/--
Efficiently computes `x * 2^i`.
-/
@[extern "lean_float32_scaleb"]
opaque Float32.scaleB (x : Float32) (i : @& Int) : Float32
-/
28 changes: 16 additions & 12 deletions src/Lean/Compiler/IR/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ inductive IRType where
| irrelevant | object | tobject
| struct (leanTypeName : Option Name) (types : Array IRType) : IRType
| union (leanTypeName : Name) (types : Array IRType) : IRType
| float32
deriving Inhabited, Repr

namespace IRType

partial def beq : IRType → IRType → Bool
| float, float => true
| float32, float32 => true
| uint8, uint8 => true
| uint16, uint16 => true
| uint32, uint32 => true
Expand All @@ -104,13 +106,14 @@ partial def beq : IRType → IRType → Bool
instance : BEq IRType := ⟨beq⟩

def isScalar : IRType → Bool
| float => true
| uint8 => true
| uint16 => true
| uint32 => true
| uint64 => true
| usize => true
| _ => false
| float => true
| float32 => true
| uint8 => true
| uint16 => true
| uint32 => true
| uint64 => true
| usize => true
| _ => false

def isObj : IRType → Bool
| object => true
Expand Down Expand Up @@ -611,10 +614,11 @@ def mkIf (x : VarId) (t e : FnBody) : FnBody :=

def getUnboxOpName (t : IRType) : String :=
match t with
| IRType.usize => "lean_unbox_usize"
| IRType.uint32 => "lean_unbox_uint32"
| IRType.uint64 => "lean_unbox_uint64"
| IRType.float => "lean_unbox_float"
| _ => "lean_unbox"
| IRType.usize => "lean_unbox_usize"
| IRType.uint32 => "lean_unbox_uint32"
| IRType.uint64 => "lean_unbox_uint64"
| IRType.float => "lean_unbox_float"
| IRType.float32 => "lean_unbox_float32"
| _ => "lean_unbox"

end Lean.IR
38 changes: 21 additions & 17 deletions src/Lean/Compiler/IR/EmitC.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def emitArg (x : Arg) : M Unit :=

def toCType : IRType → String
| IRType.float => "double"
| IRType.float32 => "float"
| IRType.uint8 => "uint8_t"
| IRType.uint16 => "uint16_t"
| IRType.uint32 => "uint32_t"
Expand Down Expand Up @@ -311,12 +312,13 @@ def emitUSet (x : VarId) (n : Nat) (y : VarId) : M Unit := do

def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M Unit := do
match t with
| IRType.float => emit "lean_ctor_set_float"
| IRType.uint8 => emit "lean_ctor_set_uint8"
| IRType.uint16 => emit "lean_ctor_set_uint16"
| IRType.uint32 => emit "lean_ctor_set_uint32"
| IRType.uint64 => emit "lean_ctor_set_uint64"
| _ => throw "invalid instruction";
| IRType.float => emit "lean_ctor_set_float"
| IRType.float32 => emit "lean_ctor_set_float32"
| IRType.uint8 => emit "lean_ctor_set_uint8"
| IRType.uint16 => emit "lean_ctor_set_uint16"
| IRType.uint32 => emit "lean_ctor_set_uint32"
| IRType.uint64 => emit "lean_ctor_set_uint64"
| _ => throw "invalid instruction";
emit "("; emit x; emit ", "; emitOffset n offset; emit ", "; emit y; emitLn ");"

def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit := do
Expand Down Expand Up @@ -386,12 +388,13 @@ def emitUProj (z : VarId) (i : Nat) (x : VarId) : M Unit := do
def emitSProj (z : VarId) (t : IRType) (n offset : Nat) (x : VarId) : M Unit := do
emitLhs z;
match t with
| IRType.float => emit "lean_ctor_get_float"
| IRType.uint8 => emit "lean_ctor_get_uint8"
| IRType.uint16 => emit "lean_ctor_get_uint16"
| IRType.uint32 => emit "lean_ctor_get_uint32"
| IRType.uint64 => emit "lean_ctor_get_uint64"
| _ => throw "invalid instruction"
| IRType.float => emit "lean_ctor_get_float"
| IRType.float32 => emit "lean_ctor_get_float32"
| IRType.uint8 => emit "lean_ctor_get_uint8"
| IRType.uint16 => emit "lean_ctor_get_uint16"
| IRType.uint32 => emit "lean_ctor_get_uint32"
| IRType.uint64 => emit "lean_ctor_get_uint64"
| _ => throw "invalid instruction"
emit "("; emit x; emit ", "; emitOffset n offset; emitLn ");"

def toStringArgs (ys : Array Arg) : List String :=
Expand Down Expand Up @@ -446,11 +449,12 @@ def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit :=

def emitBoxFn (xType : IRType) : M Unit :=
match xType with
| IRType.usize => emit "lean_box_usize"
| IRType.uint32 => emit "lean_box_uint32"
| IRType.uint64 => emit "lean_box_uint64"
| IRType.float => emit "lean_box_float"
| _ => emit "lean_box"
| IRType.usize => emit "lean_box_usize"
| IRType.uint32 => emit "lean_box_uint32"
| IRType.uint64 => emit "lean_box_uint64"
| IRType.float => emit "lean_box_float"
| IRType.float32 => emit "lean_box_float32"
| _ => emit "lean_box"

def emitBox (z : VarId) (x : VarId) (xType : IRType) : M Unit := do
emitLhs z; emitBoxFn xType; emit "("; emit x; emitLn ");"
Expand Down
49 changes: 27 additions & 22 deletions src/Lean/Compiler/IR/EmitLLVM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def callLeanCtorSetTag (builder : LLVM.Builder llvmctx)
def toLLVMType (t : IRType) : M llvmctx (LLVM.LLVMType llvmctx) := do
match t with
| IRType.float => LLVM.doubleTypeInContext llvmctx
| IRType.float32 => LLVM.floatTypeInContext llvmctx
| IRType.uint8 => LLVM.intTypeInContext llvmctx 8
| IRType.uint16 => LLVM.intTypeInContext llvmctx 16
| IRType.uint32 => LLVM.intTypeInContext llvmctx 32
Expand Down Expand Up @@ -817,12 +818,13 @@ def emitSProj (builder : LLVM.Builder llvmctx)
(z : VarId) (t : IRType) (n offset : Nat) (x : VarId) : M llvmctx Unit := do
let (fnName, retty) ←
match t with
| IRType.float => pure ("lean_ctor_get_float", ← LLVM.doubleTypeInContext llvmctx)
| IRType.uint8 => pure ("lean_ctor_get_uint8", ← LLVM.i8Type llvmctx)
| IRType.uint16 => pure ("lean_ctor_get_uint16", ← LLVM.i16Type llvmctx)
| IRType.uint32 => pure ("lean_ctor_get_uint32", ← LLVM.i32Type llvmctx)
| IRType.uint64 => pure ("lean_ctor_get_uint64", ← LLVM.i64Type llvmctx)
| _ => throw s!"Invalid type for lean_ctor_get: '{t}'"
| IRType.float => pure ("lean_ctor_get_float", ← LLVM.doubleTypeInContext llvmctx)
| IRType.float32 => pure ("lean_ctor_get_float32", ← LLVM.floatTypeInContext llvmctx)
| IRType.uint8 => pure ("lean_ctor_get_uint8", ← LLVM.i8Type llvmctx)
| IRType.uint16 => pure ("lean_ctor_get_uint16", ← LLVM.i16Type llvmctx)
| IRType.uint32 => pure ("lean_ctor_get_uint32", ← LLVM.i32Type llvmctx)
| IRType.uint64 => pure ("lean_ctor_get_uint64", ← LLVM.i64Type llvmctx)
| _ => throw s!"Invalid type for lean_ctor_get: '{t}'"
let argtys := #[ ← LLVM.voidPtrType llvmctx, ← LLVM.unsignedType llvmctx]
let fn ← getOrCreateFunctionPrototype (← getLLVMModule) retty fnName argtys
let xval ← emitLhsVal builder x
Expand Down Expand Up @@ -862,11 +864,12 @@ def emitBox (builder : LLVM.Builder llvmctx) (z : VarId) (x : VarId) (xType : IR
let xv ← emitLhsVal builder x
let (fnName, argTy, xv) ←
match xType with
| IRType.usize => pure ("lean_box_usize", ← LLVM.size_tType llvmctx, xv)
| IRType.uint32 => pure ("lean_box_uint32", ← LLVM.i32Type llvmctx, xv)
| IRType.uint64 => pure ("lean_box_uint64", ← LLVM.size_tType llvmctx, xv)
| IRType.float => pure ("lean_box_float", ← LLVM.doubleTypeInContext llvmctx, xv)
| _ => do
| IRType.usize => pure ("lean_box_usize", ← LLVM.size_tType llvmctx, xv)
| IRType.uint32 => pure ("lean_box_uint32", ← LLVM.i32Type llvmctx, xv)
| IRType.uint64 => pure ("lean_box_uint64", ← LLVM.size_tType llvmctx, xv)
| IRType.float => pure ("lean_box_float", ← LLVM.doubleTypeInContext llvmctx, xv)
| IRType.float32 => pure ("lean_box_float32", ← LLVM.floatTypeInContext llvmctx, xv)
| _ =>
-- sign extend smaller values into i64
let xv ← LLVM.buildSext builder xv (← LLVM.size_tType llvmctx)
pure ("lean_box", ← LLVM.size_tType llvmctx, xv)
Expand All @@ -892,11 +895,12 @@ def callUnboxForType (builder : LLVM.Builder llvmctx)
(retName : String := "") : M llvmctx (LLVM.Value llvmctx) := do
let (fnName, retty) ←
match t with
| IRType.usize => pure ("lean_unbox_usize", ← toLLVMType t)
| IRType.uint32 => pure ("lean_unbox_uint32", ← toLLVMType t)
| IRType.uint64 => pure ("lean_unbox_uint64", ← toLLVMType t)
| IRType.float => pure ("lean_unbox_float", ← toLLVMType t)
| _ => pure ("lean_unbox", ← LLVM.size_tType llvmctx)
| IRType.usize => pure ("lean_unbox_usize", ← toLLVMType t)
| IRType.uint32 => pure ("lean_unbox_uint32", ← toLLVMType t)
| IRType.uint64 => pure ("lean_unbox_uint64", ← toLLVMType t)
| IRType.float => pure ("lean_unbox_float", ← toLLVMType t)
| IRType.float32 => pure ("lean_unbox_float32", ← toLLVMType t)
| _ => pure ("lean_unbox", ← LLVM.size_tType llvmctx)
let argtys := #[← LLVM.voidPtrType llvmctx ]
let fn ← getOrCreateFunctionPrototype (← getLLVMModule) retty fnName argtys
let fnty ← LLVM.functionType retty argtys
Expand Down Expand Up @@ -1041,12 +1045,13 @@ def emitJmp (builder : LLVM.Builder llvmctx) (jp : JoinPointId) (xs : Array Arg)
def emitSSet (builder : LLVM.Builder llvmctx) (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M llvmctx Unit := do
let (fnName, setty) ←
match t with
| IRType.float => pure ("lean_ctor_set_float", ← LLVM.doubleTypeInContext llvmctx)
| IRType.uint8 => pure ("lean_ctor_set_uint8", ← LLVM.i8Type llvmctx)
| IRType.uint16 => pure ("lean_ctor_set_uint16", ← LLVM.i16Type llvmctx)
| IRType.uint32 => pure ("lean_ctor_set_uint32", ← LLVM.i32Type llvmctx)
| IRType.uint64 => pure ("lean_ctor_set_uint64", ← LLVM.i64Type llvmctx)
| _ => throw s!"invalid type for 'lean_ctor_set': '{t}'"
| IRType.float => pure ("lean_ctor_set_float", ← LLVM.doubleTypeInContext llvmctx)
| IRType.float32 => pure ("lean_ctor_set_float32", ← LLVM.floatTypeInContext llvmctx)
| IRType.uint8 => pure ("lean_ctor_set_uint8", ← LLVM.i8Type llvmctx)
| IRType.uint16 => pure ("lean_ctor_set_uint16", ← LLVM.i16Type llvmctx)
| IRType.uint32 => pure ("lean_ctor_set_uint32", ← LLVM.i32Type llvmctx)
| IRType.uint64 => pure ("lean_ctor_set_uint64", ← LLVM.i64Type llvmctx)
| _ => throw s!"invalid type for 'lean_ctor_set': '{t}'"
let argtys := #[ ← LLVM.voidPtrType llvmctx, ← LLVM.unsignedType llvmctx, setty]
let retty ← LLVM.voidType llvmctx
let fn ← getOrCreateFunctionPrototype (← getLLVMModule) retty fnName argtys
Expand Down
Loading

0 comments on commit ba71869

Please sign in to comment.