diff --git a/Cargo.lock b/Cargo.lock index ed5fbf3..c268f9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,12 +112,46 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + +[[package]] +name = "derive-visitor" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d47165df83b9707cbada3216607a5d66125b6a66906de0bc1216c0669767ca9e" +dependencies = [ + "derive-visitor-macros", +] + +[[package]] +name = "derive-visitor-macros" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "427b39a85fecafea16b1a5f3f50437151022e35eb4fe038107f08adbf7f8def6" +dependencies = [ + "convert_case", + "itertools", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "dyn-clone" version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + [[package]] name = "errno" version = "0.3.9" @@ -166,6 +200,15 @@ dependencies = [ "ahash", ] +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -240,6 +283,7 @@ dependencies = [ "arbitrary", "borsh", "bytemuck", + "derive-visitor", "num-cmp", "num-traits", "proptest", diff --git a/Cargo.toml b/Cargo.toml index 5f32724..aa52e16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ rust-version = "1.60" arbitrary = { version = "1.0.0", optional = true } borsh = { version = "1.2.0", optional = true, default-features = false } bytemuck = { version = "1.12.2", optional = true, default-features = false } +derive-visitor = { version = "0.4.0", optional = true } num-cmp = { version = "0.1.0", optional = true } num-traits = { version = "0.2.1", default-features = false } proptest = { version = "1.0.0", optional = true } diff --git a/src/lib.rs b/src/lib.rs index f383c8d..68c3a05 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -105,6 +105,71 @@ fn canonicalize_signed_zero(x: T) -> T { #[repr(transparent)] pub struct OrderedFloat(pub T); +#[cfg(feature = "derive-visitor")] +mod impl_derive_visitor { + use crate::OrderedFloat; + use derive_visitor::{Drive, DriveMut, Event, Visitor, VisitorMut}; + + impl Drive for OrderedFloat { + fn drive(&self, visitor: &mut V) { + visitor.visit(self, Event::Enter); + visitor.visit(self, Event::Exit); + } + } + + impl DriveMut for OrderedFloat { + fn drive_mut(&mut self, visitor: &mut V) { + visitor.visit(self, Event::Enter); + visitor.visit(self, Event::Exit); + } + } + + #[test] + pub fn test_derive_visitor() { + #[derive(Debug, Clone, PartialEq, Eq, Drive, DriveMut)] + pub enum Literal { + Null, + Float(OrderedFloat), + } + + #[derive(Visitor, VisitorMut)] + #[visitor(Literal(enter))] + struct FloatExpr(bool); + + impl FloatExpr { + fn enter_literal(&mut self, lit: &Literal) { + if let Literal::Float(_) = lit { + self.0 = true; + } + } + } + + assert!({ + let mut visitor = FloatExpr(false); + Literal::Null.drive(&mut visitor); + !visitor.0 + }); + + assert!({ + let mut visitor = FloatExpr(false); + Literal::Null.drive_mut(&mut visitor); + !visitor.0 + }); + + assert!({ + let mut visitor = FloatExpr(false); + Literal::Float(OrderedFloat(0.0)).drive(&mut visitor); + visitor.0 + }); + + assert!({ + let mut visitor = FloatExpr(false); + Literal::Float(OrderedFloat(0.0)).drive_mut(&mut visitor); + visitor.0 + }); + } +} + #[cfg(feature = "num-cmp")] mod impl_num_cmp { use super::OrderedFloat;