Skip to content

Commit

Permalink
More rigorous treatment of floats in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
leoyvens committed Nov 28, 2024
1 parent fc49e8d commit 9942c60
Show file tree
Hide file tree
Showing 27 changed files with 1,141 additions and 1,093 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ arrow-ipc = { version = "53.3.0", default-features = false, features = [
arrow-ord = { version = "53.3.0", default-features = false }
arrow-schema = { version = "53.3.0", default-features = false }
async-trait = "0.1.73"
bigdecimal = "0.4.6"
bytes = "1.4"
chrono = { version = "0.4.38", default-features = false }
ctor = "0.2.0"
Expand Down
1 change: 0 additions & 1 deletion datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ rand = { workspace = true, features = ["small_rng"] }
rand_distr = "0.4.3"
regex = { workspace = true }
rstest = { workspace = true }
rust_decimal = { version = "1.27.0", features = ["tokio-pg"] }
serde_json = { workspace = true }
test-utils = { path = "../../test-utils" }
tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] }
Expand Down
5 changes: 3 additions & 2 deletions datafusion/sqllogictest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ path = "src/lib.rs"
[dependencies]
arrow = { workspace = true }
async-trait = { workspace = true }
bigdecimal = { workspace = true }
bytes = { workspace = true, optional = true }
chrono = { workspace = true, optional = true }
clap = { version = "4.5.16", features = ["derive", "env"] }
Expand All @@ -50,13 +49,14 @@ log = { workspace = true }
object_store = { workspace = true }
postgres-protocol = { version = "0.6.4", optional = true }
postgres-types = { version = "0.2.4", optional = true }
rust_decimal = { version = "1.27.0" }
rust_decimal = { version = "1.27.0", features = ["tokio-pg"], optional = true }
sqllogictest = "0.23.0"
sqlparser = { workspace = true }
tempfile = { workspace = true }
thiserror = "2.0.0"
tokio = { workspace = true }
tokio-postgres = { version = "0.7.7", optional = true }
ryu = "1.0.18"

[features]
avro = ["datafusion/avro"]
Expand All @@ -66,6 +66,7 @@ postgres = [
"tokio-postgres",
"postgres-types",
"postgres-protocol",
"rust_decimal",
]

[dev-dependencies]
Expand Down
107 changes: 17 additions & 90 deletions datafusion/sqllogictest/src/engines/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
// under the License.

use arrow::datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType};
use bigdecimal::BigDecimal;
use half::f16;
use rust_decimal::prelude::*;

/// Represents a constant for NULL string in your database.
pub const NULL_STR: &str = "NULL";
Expand All @@ -40,17 +38,7 @@ pub(crate) fn varchar_to_str(value: &str) -> String {
}

pub(crate) fn f16_to_str(value: f16) -> String {
if value.is_nan() {
// The sign of NaN can be different depending on platform.
// So the string representation of NaN ignores the sign.
"NaN".to_string()
} else if value == f16::INFINITY {
"Infinity".to_string()
} else if value == f16::NEG_INFINITY {
"-Infinity".to_string()
} else {
big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap())
}
f32_to_str(value.to_f32())
}

pub(crate) fn f32_to_str(value: f32) -> String {
Expand All @@ -63,7 +51,7 @@ pub(crate) fn f32_to_str(value: f32) -> String {
} else if value == f32::NEG_INFINITY {
"-Infinity".to_string()
} else {
big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap())
trim_decimal_trailing_zeros(ryu::Buffer::new().format(value).to_string())
}
}

Expand All @@ -77,94 +65,33 @@ pub(crate) fn f64_to_str(value: f64) -> String {
} else if value == f64::NEG_INFINITY {
"-Infinity".to_string()
} else {
big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap())
trim_decimal_trailing_zeros(ryu::Buffer::new().format(value).to_string())
}
}

pub(crate) fn i128_to_str(value: i128, precision: &u8, scale: &i8) -> String {
big_decimal_to_str(
BigDecimal::from_str(&Decimal128Type::format_decimal(value, *precision, *scale))
.unwrap(),
)
trim_decimal_trailing_zeros(Decimal128Type::format_decimal(value, *precision, *scale))
}

pub(crate) fn i256_to_str(value: i256, precision: &u8, scale: &i8) -> String {
big_decimal_to_str(
BigDecimal::from_str(&Decimal256Type::format_decimal(value, *precision, *scale))
.unwrap(),
)
trim_decimal_trailing_zeros(Decimal256Type::format_decimal(value, *precision, *scale))
}

#[cfg(feature = "postgres")]
pub(crate) fn decimal_to_str(value: Decimal) -> String {
big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap())
}

pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String {
// Round the value to limit the number of decimal places
let value = value.round(12).normalized();
// Format the value to a string
format_big_decimal(value)
pub(crate) fn decimal_to_str(value: rust_decimal::Decimal) -> String {
trim_decimal_trailing_zeros(value.to_string())
}

fn format_big_decimal(value: BigDecimal) -> String {
let (integer, scale) = value.into_bigint_and_exponent();
let mut str = integer.to_str_radix(10);
if scale <= 0 {
// Append zeros to the right of the integer part
str.extend(std::iter::repeat('0').take(scale.unsigned_abs() as usize));
str
} else {
let (sign, unsigned_len, unsigned_str) = if integer.is_negative() {
("-", str.len() - 1, &str[1..])
} else {
("", str.len(), &str[..])
};
let scale = scale as usize;
if unsigned_len <= scale {
format!("{}0.{:0>scale$}", sign, unsigned_str)
} else {
str.insert(str.len() - scale, '.');
str
fn trim_decimal_trailing_zeros(mut string: String) -> String {
// Remove trailing zeros after the decimal point
if let Some(decimal_idx) = string.find('.') {
let after_decimal_idx = decimal_idx + 1;
let after = &mut string[after_decimal_idx..];
let trimmed_len = after.trim_end_matches('0').len();
string.truncate(after_decimal_idx + trimmed_len);
if string.ends_with('.') {
string.pop();
}
}
}

#[cfg(test)]
mod tests {
use super::big_decimal_to_str;
use bigdecimal::{num_bigint::BigInt, BigDecimal};

macro_rules! assert_decimal_str_eq {
($integer:expr, $scale:expr, $expected:expr) => {
assert_eq!(
big_decimal_to_str(BigDecimal::from_bigint(
BigInt::from($integer),
$scale
)),
$expected
);
};
}

#[test]
fn test_big_decimal_to_str() {
assert_decimal_str_eq!(11, 3, "0.011");
assert_decimal_str_eq!(11, 2, "0.11");
assert_decimal_str_eq!(11, 1, "1.1");
assert_decimal_str_eq!(11, 0, "11");
assert_decimal_str_eq!(11, -1, "110");
assert_decimal_str_eq!(0, 0, "0");

// Negative cases
assert_decimal_str_eq!(-11, 3, "-0.011");
assert_decimal_str_eq!(-11, 2, "-0.11");
assert_decimal_str_eq!(-11, 1, "-1.1");
assert_decimal_str_eq!(-11, 0, "-11");
assert_decimal_str_eq!(-11, -1, "-110");

// Round to 12 decimal places
// 1.0000000000011 -> 1.000000000001
assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, "1.000000000001");
}
string
}
Loading

0 comments on commit 9942c60

Please sign in to comment.