Skip to content

Commit

Permalink
switch float math udfs from std to libm
Browse files Browse the repository at this point in the history
  • Loading branch information
leoyvens committed Nov 29, 2024
1 parent 70e7e62 commit 4ca9592
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 12 deletions.
1 change: 1 addition & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ regex = { workspace = true, optional = true }
sha2 = { version = "^0.10.1", optional = true }
unicode-segmentation = { version = "^1.7.1", optional = true }
uuid = { version = "1.7", features = ["v4"], optional = true }
libm = "0.2.11"

[dev-dependencies]
arrow = { workspace = true, features = ["test_utils"] }
Expand Down
42 changes: 38 additions & 4 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ macro_rules! make_math_unary_udf {
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};

#[allow(unused_imports)]
use $crate::macros::StdFloat as _;

#[derive(Debug)]
pub struct $UDF {
signature: Signature,
Expand Down Expand Up @@ -218,12 +221,16 @@ macro_rules! make_math_unary_udf {
DataType::Float64 => Arc::new(
args[0]
.as_primitive::<Float64Type>()
.unary::<_, Float64Type>(|x: f64| f64::$UNARY_FUNC(x)),
.unary::<_, Float64Type>(|x: f64| {
libm::Libm::<f64>::$UNARY_FUNC(x)
}),
) as ArrayRef,
DataType::Float32 => Arc::new(
args[0]
.as_primitive::<Float32Type>()
.unary::<_, Float32Type>(|x: f32| f32::$UNARY_FUNC(x)),
.unary::<_, Float32Type>(|x: f32| {
libm::Libm::<f32>::$UNARY_FUNC(x)
}),
) as ArrayRef,
other => {
return exec_err!(
Expand Down Expand Up @@ -333,7 +340,7 @@ macro_rules! make_math_binary_udf {
let result = arrow::compute::binary::<_, _, _, Float64Type>(
y,
x,
|y, x| f64::$BINARY_FUNC(y, x),
|y, x| libm::Libm::<f64>::$BINARY_FUNC(y, x),
)?;
Arc::new(result) as _
}
Expand All @@ -343,7 +350,7 @@ macro_rules! make_math_binary_udf {
let result = arrow::compute::binary::<_, _, _, Float32Type>(
y,
x,
|y, x| f32::$BINARY_FUNC(y, x),
|y, x| libm::Libm::<f32>::$BINARY_FUNC(y, x),
)?;
Arc::new(result) as _
}
Expand All @@ -365,3 +372,30 @@ macro_rules! make_math_binary_udf {
}
};
}

/// Adds methods that exist in std but are missing from libm,
/// because they are not platform intrinsics but just conveniences.
pub trait StdFloat<T> {
fn to_radians(x: T) -> T;
fn to_degrees(x: T) -> T;
}

impl StdFloat<f64> for libm::Libm<f64> {
fn to_radians(x: f64) -> f64 {
x.to_radians()
}

fn to_degrees(x: f64) -> f64 {
x.to_degrees()
}
}

impl StdFloat<f32> for libm::Libm<f32> {
fn to_radians(x: f32) -> f32 {
x.to_radians()
}

fn to_degrees(x: f32) -> f32 {
x.to_degrees()
}
}
2 changes: 1 addition & 1 deletion datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ make_math_unary_udf!(
LnFunc,
LN,
ln,
ln,
log,
super::ln_order,
super::bounds::unbounded_bounds,
super::get_ln_doc
Expand Down
14 changes: 7 additions & 7 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ select abs(a), abs(b), abs(c) from signed_integers;
query RRR rowsort
select acos(0), acos(0.5), acos(1);
----
1.5707963267948966 1.0471975511965976 0
1.5707963267948966 1.0471975511965979 0

# acos scalar nulls
query R rowsort
Expand Down Expand Up @@ -140,7 +140,7 @@ NaN 9.90349 NaN
query RRR rowsort
select asin(0), asin(0.5), asin(1);
----
0 0.5235987755982988 1.5707963267948966
0 0.5235987755982989 1.5707963267948966

# asin scalar nulls
query R rowsort
Expand Down Expand Up @@ -362,7 +362,7 @@ select round(cos(a), 5), round(cos(b), 5), round(cos(c), 5) from signed_integers
query RRR rowsort
select cosh(1), cosh(2), cosh(3);
----
1.5430806348152437 3.7621956910836314 10.067661995777765
1.543080634815244 3.7621956910836314 10.067661995777765

# cosh scalar nulls
query R rowsort
Expand All @@ -385,7 +385,7 @@ select round(cosh(a), 5), round(cosh(b), 5), round(cosh(c), 5) from small_floats
query RRR rowsort
select exp(0), exp(1), exp(2);
----
1 2.718281828459045 7.38905609893065
1 2.7182818284590455 7.38905609893065

# exp scalar nulls
query R rowsort
Expand Down Expand Up @@ -454,7 +454,7 @@ select floor(a), floor(b), floor(c) from signed_integers;
query RRR rowsort
select ln(1), ln(exp(1)), ln(3);
----
0 1 1.0986122886681098
0 1 1.0986122886681096

# ln scalar nulls
query R rowsort
Expand Down Expand Up @@ -872,7 +872,7 @@ select round(sin(a), 5), round(sin(b), 5), round(sin(c), 5) from small_floats;
query RRR rowsort
select sinh(1), sinh(2), sinh(3);
----
1.1752011936438014 3.6268604078470186 10.017874927409903
1.1752011936438014 3.626860407847019 10.017874927409903

# sinh scalar nulls
query R rowsort
Expand Down Expand Up @@ -937,7 +937,7 @@ NaN
query RRR rowsort
select tan(0), tan(pi() / 6), tan(pi() / 4);
----
0 0.5773502691896256 0.9999999999999999
0 0.5773502691896257 0.9999999999999999

# tan scalar nulls
query R rowsort
Expand Down

0 comments on commit 4ca9592

Please sign in to comment.