diff --git a/Cargo.toml b/Cargo.toml index 76bc50d59a08..312624c769c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -122,6 +122,7 @@ datafusion-proto = { path = "datafusion/proto", version = "43.0.0" } datafusion-proto-common = { path = "datafusion/proto-common", version = "43.0.0" } datafusion-sql = { path = "datafusion/sql", version = "43.0.0" } doc-comment = "0.3" +enumset = "1.1.5" env_logger = "0.11" futures = "0.3" half = { version = "2.2.1", default-features = false } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 04b0b0d22cfd..65ba70eb8dcd 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -220,7 +220,7 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.15.1", + "hashbrown 0.15.2", "num", ] @@ -406,9 +406,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.17" +version = "0.4.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" +checksum = "df895a515f70646414f4b45c0b79082783b80552b373a68283012928df56f522" dependencies = [ "bzip2", "flate2", @@ -814,9 +814,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.4" +version = "1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e" dependencies = [ "arrayref", "arrayvec", @@ -880,9 +880,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "bytes-utils" @@ -917,9 +917,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +checksum = "f34d93e62b03caf570cccc334cbc6c2fceca82f39211051345108adcba3eebdc" dependencies = [ "jobserver", "libc", @@ -1080,6 +1080,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1097,9 +1107,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" dependencies = [ "libc", ] @@ -1166,6 +1176,40 @@ dependencies = [ "syn", ] +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn", +] + [[package]] name = "dary_heap" version = "0.3.7" @@ -1216,6 +1260,7 @@ dependencies = [ "datafusion-physical-optimizer", "datafusion-physical-plan", "datafusion-sql", + "enumset", "flate2", "futures", "glob", @@ -1349,6 +1394,7 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-functions-window-common", "datafusion-physical-expr-common", + "enumset", "indexmap", "paste", "recursive", @@ -1487,6 +1533,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", + "enumset", "indexmap", "itertools", "log", @@ -1669,6 +1716,27 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "enumset" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d07a4b049558765cef5f0c1a273c3fc57084d768b44d2f98127aef4cceb17293" +dependencies = [ + "enumset_derive", +] + +[[package]] +name = "enumset_derive" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59c3b24c345d8c314966bdc1832f6c2635bfcce8e7cf363bd115987bba2ee242" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "env_filter" version = "0.1.2" @@ -1700,12 +1768,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1972,9 +2040,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" [[package]] name = "heck" @@ -1988,12 +2056,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" - [[package]] name = "hex" version = "0.4.3" @@ -2162,8 +2224,8 @@ dependencies = [ "http 1.1.0", "hyper 1.5.1", "hyper-util", - "rustls 0.23.17", - "rustls-native-certs 0.8.0", + "rustls 0.23.19", + "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -2330,6 +2392,12 @@ dependencies = [ "syn", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.0.3" @@ -2358,7 +2426,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.15.1", + "hashbrown 0.15.2", ] [[package]] @@ -2390,9 +2458,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2" +checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "jobserver" @@ -2405,9 +2473,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.72" +version = "0.3.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" +checksum = "fb15147158e79fd8b8afd0252522769c4f48725460b37338544d8379d94fc8f9" dependencies = [ "wasm-bindgen", ] @@ -2484,9 +2552,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.164" +version = "0.2.167" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" +checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" [[package]] name = "libflate" @@ -2546,9 +2614,9 @@ checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "litemap" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" [[package]] name = "lock_api" @@ -2628,11 +2696,10 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ - "hermit-abi", "libc", "wasi", "windows-sys 0.52.0", @@ -2862,7 +2929,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.15.1", + "hashbrown 0.15.2", "lz4_flex", "num", "num-bigint", @@ -3020,9 +3087,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.89" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -3063,7 +3130,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.17", + "rustls 0.23.19", "socket2", "thiserror 2.0.3", "tokio", @@ -3081,7 +3148,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.17", + "rustls 0.23.19", "rustls-pki-types", "slab", "thiserror 2.0.3", @@ -3259,8 +3326,8 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.17", - "rustls-native-certs 0.8.0", + "rustls 0.23.19", + "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", @@ -3378,9 +3445,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.17" +version = "0.23.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" +checksum = "934b404430bb06b3fae2cba809eb45a1ab1aecd64491213d7c3301b88393f8d1" dependencies = [ "once_cell", "ring", @@ -3399,20 +3466,19 @@ dependencies = [ "openssl-probe", "rustls-pemfile 1.0.4", "schannel", - "security-framework", + "security-framework 2.11.1", ] [[package]] name = "rustls-native-certs" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" dependencies = [ "openssl-probe", - "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.0.1", ] [[package]] @@ -3538,7 +3604,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", - "core-foundation", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8" +dependencies = [ + "bitflags 2.6.0", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -3686,9 +3765,9 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", "windows-sys 0.52.0", @@ -3801,9 +3880,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.87" +version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ "proc-macro2", "quote", @@ -4009,7 +4088,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.17", + "rustls 0.23.19", "rustls-pki-types", "tokio", ] @@ -4052,9 +4131,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.40" +version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -4063,9 +4142,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", @@ -4074,9 +4153,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" dependencies = [ "once_cell", ] @@ -4155,9 +4234,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.3" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", "idna", @@ -4246,9 +4325,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.95" +version = "0.2.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" +checksum = "21d3b25c3ea1126a2ad5f4f9068483c2af1e64168f847abe863a526b8dbfe00b" dependencies = [ "cfg-if", "once_cell", @@ -4257,9 +4336,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.95" +version = "0.2.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" +checksum = "52857d4c32e496dc6537646b5b117081e71fd2ff06de792e3577a150627db283" dependencies = [ "bumpalo", "log", @@ -4272,21 +4351,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.45" +version = "0.4.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" +checksum = "951fe82312ed48443ac78b66fa43eded9999f738f6022e67aead7b708659e49a" dependencies = [ "cfg-if", "js-sys", + "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.95" +version = "0.2.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" +checksum = "920b0ffe069571ebbfc9ddc0b36ba305ef65577c94b06262ed793716a1afd981" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4294,9 +4374,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.95" +version = "0.2.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" +checksum = "bf59002391099644be3524e23b781fa43d2be0c5aa0719a18c0731b9d195cab6" dependencies = [ "proc-macro2", "quote", @@ -4307,9 +4387,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.95" +version = "0.2.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" +checksum = "e5047c5392700766601942795a436d7d2599af60dcc3cc1248c9120bfb0827b0" [[package]] name = "wasm-streams" @@ -4326,9 +4406,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.72" +version = "0.3.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +checksum = "476364ff87d0ae6bfb661053a9104ab312542658c3d8f963b7ace80b6f9b26b9" dependencies = [ "js-sys", "wasm-bindgen", @@ -4578,9 +4658,9 @@ dependencies = [ [[package]] name = "yoke" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" dependencies = [ "serde", "stable_deref_trait", @@ -4590,9 +4670,9 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", @@ -4623,18 +4703,18 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 268e0fb17f7b..c0b18cc05005 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -110,6 +110,7 @@ datafusion-physical-expr-common = { workspace = true } datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } +enumset = { workspace = true } flate2 = { version = "1.0.24", optional = true } futures = { workspace = true } glob = "0.3.0" diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 8b313d13e83e..12021039c1e1 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -69,6 +69,7 @@ use arrow::{ util::pretty::pretty_format_batches, }; use async_trait::async_trait; +use enumset::enum_set; use futures::{Stream, StreamExt}; use datafusion::execution::session_state::SessionStateBuilder; @@ -97,8 +98,8 @@ use datafusion::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{FetchType, Projection, SortExpr}; -use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; /// Execute the specified sql and return the resulting record batches @@ -343,10 +344,6 @@ impl OptimizerRule for TopKOptimizerRule { "topk" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -357,38 +354,47 @@ impl OptimizerRule for TopKOptimizerRule { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { - // Note: this code simply looks for the pattern of a Limit followed by a - // Sort and replaces it by a TopK node. It does not handle many - // edge cases (e.g multiple sort columns, sort ASC / DESC), etc. - let LogicalPlan::Limit(ref limit, _) = plan else { - return Ok(Transformed::no(plan)); - }; - let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else { - return Ok(Transformed::no(plan)); - }; + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_all_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanLimit + | LogicalPlanPattern::LogicalPlanSort + )) { + return Ok(Transformed::jump(plan)); + } - if let LogicalPlan::Sort( - Sort { - ref expr, - ref input, - .. - }, - _, - ) = limit.input.as_ref() - { - if expr.len() == 1 { - // we found a sort with a single sort expr, replace with a a TopK - return Ok(Transformed::yes(LogicalPlan::extension(Extension { - node: Arc::new(TopKPlanNode { - k: fetch, - input: input.as_ref().clone(), - expr: expr[0].clone(), - }), - }))); + // Note: this code simply looks for the pattern of a Limit followed by a + // Sort and replaces it by a TopK node. It does not handle many + // edge cases (e.g multiple sort columns, sort ASC / DESC), etc. + let LogicalPlan::Limit(ref limit, _) = plan else { + return Ok(Transformed::no(plan)); + }; + let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else { + return Ok(Transformed::no(plan)); + }; + + if let LogicalPlan::Sort( + Sort { + ref expr, + ref input, + .. + }, + _, + ) = limit.input.as_ref() + { + if expr.len() == 1 { + // we found a sort with a single sort expr, replace with a a TopK + return Ok(Transformed::yes(LogicalPlan::extension(Extension { + node: Arc::new(TopKPlanNode { + k: fetch, + input: input.as_ref().clone(), + expr: expr[0].clone(), + }), + }))); + } } - } - Ok(Transformed::no(plan)) + Ok(Transformed::no(plan)) + }) } } diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 438662e0642b..405710f2c019 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -49,6 +49,7 @@ datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } +enumset = { workspace = true } indexmap = { workspace = true } paste = "^1.0" recursive = { workspace = true } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 154e880c955c..88c654229d50 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -29,7 +29,7 @@ use crate::utils::expr_to_columns; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; -use crate::logical_plan::tree_node::LogicalPlanStats; +use crate::logical_plan::tree_node::{LogicalPlanPattern, LogicalPlanStats}; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{ @@ -39,6 +39,7 @@ use datafusion_common::{ plan_err, Column, DFSchema, HashMap, Result, ScalarValue, TableReference, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use enumset::enum_set; use sqlparser::ast::{ display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, NullTreatment, RenameSelectItem, ReplaceSelectElement, @@ -1775,7 +1776,8 @@ impl Expr { } pub fn binary_expr(binary_expr: BinaryExpr) -> Self { - let stats = binary_expr.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprBinaryExpr)) + .merge(binary_expr.stats()); Expr::BinaryExpr(binary_expr, stats) } @@ -1785,7 +1787,8 @@ impl Expr { } pub fn _like(like: Like) -> Self { - let stats = like.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprLike)) + .merge(like.stats()); Expr::Like(like, stats) } @@ -1795,27 +1798,33 @@ impl Expr { } pub fn in_subquery(in_subquery: InSubquery) -> Self { - let stats = in_subquery.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprInSubquery)) + .merge(in_subquery.stats()); Expr::InSubquery(in_subquery, stats) } pub fn scalar_subquery(subquery: Subquery) -> Self { - let stats = subquery.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprScalarSubquery)) + .merge(subquery.stats()); Expr::ScalarSubquery(subquery, stats) } pub fn _not(expr: Box) -> Self { - let stats = expr.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprNot)) + .merge(expr.stats()); Expr::Not(expr, stats) } pub fn _is_not_null(expr: Box) -> Self { - let stats = expr.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprIsNotNull)) + .merge(expr.stats()); Expr::IsNotNull(expr, stats) } pub fn _is_null(expr: Box) -> Self { - let stats = expr.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprIsNull)) + .merge(expr.stats()); Expr::IsNull(expr, stats) } @@ -1830,7 +1839,8 @@ impl Expr { } pub fn _is_unknown(expr: Box) -> Self { - let stats = expr.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprIsUnknown)) + .merge(expr.stats()); Expr::IsUnknown(expr, stats) } @@ -1845,47 +1855,60 @@ impl Expr { } pub fn _is_not_unknown(expr: Box) -> Self { - let stats = expr.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprIsNotUnknown)) + .merge(expr.stats()); Expr::IsNotUnknown(expr, stats) } pub fn negative(expr: Box) -> Self { - let stats = expr.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprNegative)) + .merge(expr.stats()); Expr::Negative(expr, stats) } pub fn _between(between: Between) -> Self { - let stats = between.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprBetween)) + .merge(between.stats()); Expr::Between(between, stats) } pub fn case(case: Case) -> Self { - let stats = case.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprCase)) + .merge(case.stats()); Expr::Case(case, stats) } pub fn cast(cast: Cast) -> Self { - let stats = cast.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprCast)) + .merge(cast.stats()); Expr::Cast(cast, stats) } pub fn try_cast(try_cast: TryCast) -> Self { - let stats = try_cast.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprTryCast)) + .merge(try_cast.stats()); Expr::TryCast(try_cast, stats) } pub fn scalar_function(scalar_function: ScalarFunction) -> Self { - let stats = scalar_function.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprScalarFunction)) + .merge(scalar_function.stats()); Expr::ScalarFunction(scalar_function, stats) } pub fn window_function(window_function: WindowFunction) -> Self { - let stats = window_function.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprWindowFunction)) + .merge(window_function.stats()); Expr::WindowFunction(window_function, stats) } pub fn aggregate_function(aggregate_function: AggregateFunction) -> Self { - let stats = aggregate_function.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprAggregateFunction)) + .merge(aggregate_function.stats()); Expr::AggregateFunction(aggregate_function, stats) } @@ -1895,22 +1918,24 @@ impl Expr { } pub fn _in_list(in_list: InList) -> Self { - let stats = in_list.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprInList)) + .merge(in_list.stats()); Expr::InList(in_list, stats) } pub fn exists(exists: Exists) -> Self { - let stats = exists.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprExists)) + .merge(exists.stats()); Expr::Exists(exists, stats) } pub fn literal(scalar_value: ScalarValue) -> Self { - let stats = LogicalPlanStats::empty(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprLiteral)); Expr::Literal(scalar_value, stats) } pub fn column(column: Column) -> Self { - let stats = LogicalPlanStats::empty(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprColumn)); Expr::Column(column, stats) } @@ -1920,7 +1945,7 @@ impl Expr { } pub fn placeholder(placeholder: Placeholder) -> Self { - let stats = LogicalPlanStats::empty(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::ExprPlaceholder)); Expr::Placeholder(placeholder, stats) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6c292a11c523..40a6be46f05d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -17,6 +17,7 @@ //! Logical plan types +use std::cell::Cell; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; @@ -54,11 +55,12 @@ use datafusion_common::{ FunctionalDependencies, ParamValues, Result, ScalarValue, TableReference, UnnestOptions, }; +use enumset::enum_set; use indexmap::IndexSet; // backwards compatibility use crate::display::PgJsonVisitor; -use crate::logical_plan::tree_node::LogicalPlanStats; +use crate::logical_plan::tree_node::{LogicalPlanPattern, LogicalPlanStats}; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -486,6 +488,13 @@ impl LogicalPlan { let mut using_columns: Vec> = vec![]; self.apply_with_subqueries(|plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanJoin) + { + return Ok(TreeNodeRecursion::Jump); + } + if let LogicalPlan::Join( Join { join_constraint: JoinConstraint::Using, @@ -1505,31 +1514,71 @@ impl LogicalPlan { self, param_values: &ParamValues, ) -> Result { - self.transform_up_with_subqueries(|plan| { - let schema = Arc::clone(plan.schema()); - let name_preserver = NamePreserver::new(&plan); - plan.map_expressions(|e| { - let (e, has_placeholder) = e.infer_placeholder_types(&schema)?; - if !has_placeholder { - // Performance optimization: - // avoid NamePreserver copy and second pass over expression - // if no placeholders. - Ok(Transformed::no(e)) - } else { - let original_name = name_preserver.save(&e); - let transformed_expr = e.transform_up(|e| { - if let Expr::Placeholder(Placeholder { id, .. }, _) = e { - let value = param_values.get_placeholders_with_values(&id)?; - Ok(Transformed::yes(Expr::literal(value))) - } else { - Ok(Transformed::no(e)) - } - })?; - // Preserve name to avoid breaking column references to this expression - Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) + let skip = Cell::new(false); + self.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::ExprPlaceholder) + { + skip.set(true); + return Ok(Transformed::jump(plan)); } - }) - }) + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + let schema = Arc::clone(plan.schema()); + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|e| { + let (e, has_placeholder) = e.infer_placeholder_types(&schema)?; + if !has_placeholder { + // Performance optimization: + // avoid NamePreserver copy and second pass over expression + // if no placeholders. + Ok(Transformed::no(e)) + } else { + let original_name = name_preserver.save(&e); + let skip = Cell::new(false); + let transformed_expr = e.transform_down_up( + |e| { + if !e + .stats() + .contains_pattern(LogicalPlanPattern::ExprPlaceholder) + { + skip.set(true); + return Ok(Transformed::jump(e)); + } + + Ok(Transformed::no(e)) + }, + |e| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(e)); + } + + if let Expr::Placeholder(Placeholder { id, .. }, _) = e { + let value = + param_values.get_placeholders_with_values(&id)?; + Ok(Transformed::yes(Expr::literal(value))) + } else { + Ok(Transformed::no(e)) + } + }, + )?; + // Preserve name to avoid breaking column references to this expression + Ok(transformed_expr + .update_data(|expr| original_name.restore(expr))) + } + }) + }, + ) .map(|res| res.data) } @@ -1537,8 +1586,22 @@ impl LogicalPlan { pub fn get_parameter_names(&self) -> Result> { let mut param_names = HashSet::new(); self.apply_with_subqueries(|plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::ExprPlaceholder) + { + return Ok(TreeNodeRecursion::Jump); + } + plan.apply_expressions(|expr| { expr.apply(|expr| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::ExprPlaceholder) + { + return Ok(TreeNodeRecursion::Jump); + } + if let Expr::Placeholder(Placeholder { id, .. }, _) = expr { param_names.insert(id.clone()); } @@ -1556,8 +1619,22 @@ impl LogicalPlan { let mut param_types: HashMap> = HashMap::new(); self.apply_with_subqueries(|plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::ExprPlaceholder) + { + return Ok(TreeNodeRecursion::Jump); + } + plan.apply_expressions(|expr| { expr.apply(|expr| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::ExprPlaceholder) + { + return Ok(TreeNodeRecursion::Jump); + } + if let Expr::Placeholder(Placeholder { id, data_type }, _) = expr { let prev = param_types.get(id); match (prev, data_type) { @@ -2085,12 +2162,16 @@ impl LogicalPlan { } pub fn projection(projection: Projection) -> Self { - let stats = projection.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanProjection)) + .merge(projection.stats()); LogicalPlan::Projection(projection, stats) } pub fn filter(filter: Filter) -> Self { - let stats = filter.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanFilter)) + .merge(filter.stats()); LogicalPlan::Filter(filter, stats) } @@ -2100,32 +2181,42 @@ impl LogicalPlan { } pub fn window(window: Window) -> Self { - let stats = window.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanWindow)) + .merge(window.stats()); LogicalPlan::Window(window, stats) } pub fn aggregate(aggregate: Aggregate) -> Self { - let stats = aggregate.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanAggregate)) + .merge(aggregate.stats()); LogicalPlan::Aggregate(aggregate, stats) } pub fn sort(sort: Sort) -> Self { - let stats = sort.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanSort)) + .merge(sort.stats()); LogicalPlan::Sort(sort, stats) } pub fn join(join: Join) -> Self { - let stats = join.stats(); + let stats = LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanJoin)) + .merge(join.stats()); LogicalPlan::Join(join, stats) } pub fn repartition(repartition: Repartition) -> Self { - let stats = repartition.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanRepartition)) + .merge(repartition.stats()); LogicalPlan::Repartition(repartition, stats) } pub fn union(projection: Union) -> Self { - let stats = projection.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanUnion)) + .merge(projection.stats()); LogicalPlan::Union(projection, stats) } @@ -2140,12 +2231,17 @@ impl LogicalPlan { } pub fn subquery_alias(subquery_alias: SubqueryAlias) -> Self { - let stats = subquery_alias.stats(); + let stats = LogicalPlanStats::new(enum_set!( + LogicalPlanPattern::LogicalPlanSubqueryAlias + )) + .merge(subquery_alias.stats()); LogicalPlan::SubqueryAlias(subquery_alias, stats) } pub fn limit(limit: Limit) -> Self { - let stats = limit.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanLimit)) + .merge(limit.stats()); LogicalPlan::Limit(limit, stats) } @@ -2170,7 +2266,9 @@ impl LogicalPlan { } pub fn distinct(distinct: Distinct) -> Self { - let stats = distinct.stats(); + let stats = + LogicalPlanStats::new(enum_set!(LogicalPlanPattern::LogicalPlanDistinct)) + .merge(distinct.stats()); LogicalPlan::Distinct(distinct, stats) } @@ -2205,7 +2303,10 @@ impl LogicalPlan { } pub fn empty_relation(empty_relation: EmptyRelation) -> Self { - let stats = LogicalPlanStats::empty(); + let stats = LogicalPlanStats::new(enum_set!( + LogicalPlanPattern::LogicalPlanEmptyRelation + )) + .merge(LogicalPlanStats::empty()); LogicalPlan::EmptyRelation(empty_relation, stats) } } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 8e472a1c737e..b22a6158af72 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -53,18 +53,107 @@ use datafusion_common::tree_node::{ TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{internal_err, Result}; +use enumset::{enum_set, EnumSet, EnumSetType}; + +#[derive(EnumSetType, Debug)] +pub enum LogicalPlanPattern { + // [`Expr`] nodes + // ExprAlias, + ExprColumn, + // ExprScalarVariable, + ExprLiteral, + ExprBinaryExpr, + ExprLike, + // ExprSimilarTo, + ExprNot, + ExprIsNotNull, + ExprIsNull, + // ExprIsTrue, + // ExprIsFalse, + ExprIsUnknown, + // ExprIsNotTrue, + // ExprIsNotFalse, + ExprIsNotUnknown, + ExprNegative, + // ExprGetIndexedField, + ExprBetween, + ExprCase, + ExprCast, + ExprTryCast, + ExprScalarFunction, + ExprAggregateFunction, + ExprWindowFunction, + ExprInList, + ExprExists, + ExprInSubquery, + ExprScalarSubquery, + // ExprWildcard, + // ExprGroupingSet, + ExprPlaceholder, + // ExprOuterReferenceColumn, + // ExprUnnest, + + // [`LogicalPlan`] nodes + LogicalPlanProjection, + LogicalPlanFilter, + LogicalPlanWindow, + LogicalPlanAggregate, + LogicalPlanSort, + LogicalPlanJoin, + // LogicalPlanCrossJoin, + LogicalPlanRepartition, + LogicalPlanUnion, + // LogicalPlanTableScan, + LogicalPlanEmptyRelation, + // LogicalPlanSubquery, + LogicalPlanSubqueryAlias, + LogicalPlanLimit, + // LogicalPlanStatement, + // LogicalPlanValues, + // LogicalPlanExplain, + // LogicalPlanAnalyze, + // LogicalPlanExtension, + LogicalPlanDistinct, + // LogicalPlanDml, + // LogicalPlanDdl, + // LogicalPlanCopy, + // LogicalPlanDescribeTable, + // LogicalPlanUnnest, + // LogicalPlanRecursiveQuery, +} #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] -pub struct LogicalPlanStats {} +pub struct LogicalPlanStats { + patterns: EnumSet, +} impl LogicalPlanStats { + pub(crate) fn new(patterns: EnumSet) -> Self { + Self { patterns } + } + pub(crate) fn empty() -> Self { - Self {} + Self { + patterns: EnumSet::empty(), + } } - pub(crate) fn merge(self, _other: LogicalPlanStats) -> Self { + pub(crate) fn merge(mut self, other: LogicalPlanStats) -> Self { + self.patterns.insert_all(other.patterns); self } + + pub fn contains_pattern(&self, pattern: LogicalPlanPattern) -> bool { + self.patterns.contains(pattern) + } + + pub fn contains_all_patterns(&self, patterns: EnumSet) -> bool { + self.patterns.is_superset(patterns) + } + + pub fn contains_any_patterns(&self, patterns: EnumSet) -> bool { + !self.patterns.is_disjoint(patterns) + } } impl TreeNode for LogicalPlan { @@ -912,16 +1001,26 @@ impl LogicalPlan { mut f: F, ) -> Result { self.apply_expressions(|expr| { - expr.apply(|expr| match expr { - Expr::Exists(Exists { subquery, .. }, _) - | Expr::InSubquery(InSubquery { subquery, .. }, _) - | Expr::ScalarSubquery(subquery, _) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - f(&LogicalPlan::subquery(subquery.clone())) + expr.apply(|expr| { + if !expr.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprExists + | LogicalPlanPattern::ExprInSubquery + | LogicalPlanPattern::ExprScalarSubquery + )) { + return Ok(TreeNodeRecursion::Jump); + } + + match expr { + Expr::Exists(Exists { subquery, .. }, _) + | Expr::InSubquery(InSubquery { subquery, .. }, _) + | Expr::ScalarSubquery(subquery, _) => { + // use a synthetic plan so the collector sees a + // LogicalPlan::Subquery (even though it is + // actually a Subquery alias) + f(&LogicalPlan::subquery(subquery.clone())) + } + _ => Ok(TreeNodeRecursion::Continue), } - _ => Ok(TreeNodeRecursion::Continue), }) }) } @@ -935,40 +1034,51 @@ impl LogicalPlan { mut f: F, ) -> Result> { self.map_expressions(|expr| { - expr.transform_down(|expr| match expr { - Expr::Exists(Exists { subquery, negated }, _) => { - f(LogicalPlan::subquery(subquery))?.map_data(|s| match s { - LogicalPlan::Subquery(subquery, _) => { - Ok(Expr::exists(Exists { subquery, negated })) - } - _ => internal_err!("Transformation should return Subquery"), - }) + expr.transform_down(|expr| { + if !expr.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprExists + | LogicalPlanPattern::ExprInSubquery + | LogicalPlanPattern::ExprScalarSubquery + )) { + return Ok(Transformed::jump(expr)); } - Expr::InSubquery( - InSubquery { - expr, - subquery, - negated, - }, - _, - ) => f(LogicalPlan::subquery(subquery))?.map_data(|s| match s { - LogicalPlan::Subquery(subquery, _) => { - Ok(Expr::in_subquery(InSubquery { + + match expr { + Expr::Exists(Exists { subquery, negated }, _) => { + f(LogicalPlan::subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery, _) => { + Ok(Expr::exists(Exists { subquery, negated })) + } + _ => internal_err!("Transformation should return Subquery"), + }) + } + Expr::InSubquery( + InSubquery { expr, subquery, negated, - })) - } - _ => internal_err!("Transformation should return Subquery"), - }), - Expr::ScalarSubquery(subquery, _) => f(LogicalPlan::subquery(subquery))? - .map_data(|s| match s { + }, + _, + ) => f(LogicalPlan::subquery(subquery))?.map_data(|s| match s { LogicalPlan::Subquery(subquery, _) => { - Ok(Expr::scalar_subquery(subquery)) + Ok(Expr::in_subquery(InSubquery { + expr, + subquery, + negated, + })) } _ => internal_err!("Transformation should return Subquery"), }), - _ => Ok(Transformed::no(expr)), + Expr::ScalarSubquery(subquery, _) => { + f(LogicalPlan::subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery, _) => { + Ok(Expr::scalar_subquery(subquery)) + } + _ => internal_err!("Transformation should return Subquery"), + }) + } + _ => Ok(Transformed::no(expr)), + } }) }) } diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index c0f17de6c5c5..a5380b720bc6 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -42,6 +42,7 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } +enumset = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 0727e9c8cb98..4ada14ca6df7 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -16,12 +16,15 @@ // under the License. use crate::analyzer::AnalyzerRule; +use enumset::enum_set; +use std::cell::Cell; use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_expr::expr::{AggregateFunction, WindowFunction}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; @@ -39,7 +42,61 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down_with_subqueries(analyze_internal).data() + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprWindowFunction + | LogicalPlanPattern::ExprAggregateFunction + )) { + return Ok(Transformed::jump(plan)); + } + + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr); + let skip = Cell::new(false); + let transformed_expr = expr.transform_down_up( + |expr| { + if !expr.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprWindowFunction + | LogicalPlanPattern::ExprAggregateFunction + )) { + skip.set(true); + return Ok(Transformed::jump(expr)); + } + + Ok(Transformed::no(expr)) + }, + |expr| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(expr)); + } + + match expr { + Expr::WindowFunction(mut window_function, _) + if is_count_star_window_aggregate(&window_function) => + { + window_function.args = vec![lit(COUNT_STAR_EXPANSION)]; + Ok(Transformed::yes(Expr::window_function( + window_function, + ))) + } + Expr::AggregateFunction(mut aggregate_function, _) + if is_count_star_aggregate(&aggregate_function) => + { + aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; + Ok(Transformed::yes(Expr::aggregate_function( + aggregate_function, + ))) + } + _ => Ok(Transformed::no(expr)), + } + }, + )?; + Ok(transformed_expr.update_data(|data| original_name.restore(data))) + }) + }) + .data() } fn name(&self) -> &str { @@ -67,31 +124,6 @@ fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { if udaf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } -fn analyze_internal(plan: LogicalPlan) -> Result> { - let name_preserver = NamePreserver::new(&plan); - plan.map_expressions(|expr| { - let original_name = name_preserver.save(&expr); - let transformed_expr = expr.transform_up(|expr| match expr { - Expr::WindowFunction(mut window_function, _) - if is_count_star_window_aggregate(&window_function) => - { - window_function.args = vec![lit(COUNT_STAR_EXPANSION)]; - Ok(Transformed::yes(Expr::window_function(window_function))) - } - Expr::AggregateFunction(mut aggregate_function, _) - if is_count_star_aggregate(&aggregate_function) => - { - aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; - Ok(Transformed::yes(Expr::aggregate_function( - aggregate_function, - ))) - } - _ => Ok(Transformed::no(expr)), - })?; - Ok(transformed_expr.update_data(|data| original_name.restore(data))) - }) -} - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs index 1ac32258d0db..82b95c6218e5 100644 --- a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs @@ -15,20 +15,22 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use crate::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::{Column, Result}; use datafusion_expr::builder::validate_unique_names; use datafusion_expr::expr::{PlannedReplaceSelectItem, Wildcard}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::utils::{ expand_qualified_wildcard, expand_wildcard, find_base_plan, }; use datafusion_expr::{ Distinct, DistinctOn, Expr, LogicalPlan, Projection, SubqueryAlias, }; +use enumset::enum_set; +use std::cell::Cell; +use std::sync::Arc; #[derive(Default, Debug)] pub struct ExpandWildcardRule {} @@ -43,7 +45,60 @@ impl AnalyzerRule for ExpandWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { // Because the wildcard expansion is based on the schema of the input plan, // using `transform_up_with_subqueries` here. - plan.transform_up_with_subqueries(expand_internal).data() + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanProjection + | LogicalPlanPattern::LogicalPlanSubqueryAlias + | LogicalPlanPattern::LogicalPlanDistinct + )) { + skip.set(true); + return Ok(Transformed::jump(plan)); + } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + match plan { + LogicalPlan::Projection(Projection { expr, input, .. }, _) => { + let projected_expr = expand_exprlist(&input, expr)?; + validate_unique_names("Projections", projected_expr.iter())?; + Ok(Transformed::yes( + Projection::try_new(projected_expr, Arc::clone(&input)) + .map(LogicalPlan::projection)?, + )) + } + // The schema of the plan should also be updated if the child plan is transformed. + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }, _) => { + Ok(Transformed::yes( + SubqueryAlias::try_new(input, alias) + .map(LogicalPlan::subquery_alias)?, + )) + } + LogicalPlan::Distinct(Distinct::On(distinct_on), _) => { + let projected_expr = + expand_exprlist(&distinct_on.input, distinct_on.select_expr)?; + validate_unique_names("Distinct", projected_expr.iter())?; + Ok(Transformed::yes(LogicalPlan::distinct(Distinct::On( + DistinctOn::try_new( + distinct_on.on_expr, + projected_expr, + distinct_on.sort_expr, + distinct_on.input, + )?, + )))) + } + _ => Ok(Transformed::no(plan)), + } + }, + ) + .data() } fn name(&self) -> &str { @@ -51,39 +106,6 @@ impl AnalyzerRule for ExpandWildcardRule { } } -fn expand_internal(plan: LogicalPlan) -> Result> { - match plan { - LogicalPlan::Projection(Projection { expr, input, .. }, _) => { - let projected_expr = expand_exprlist(&input, expr)?; - validate_unique_names("Projections", projected_expr.iter())?; - Ok(Transformed::yes( - Projection::try_new(projected_expr, Arc::clone(&input)) - .map(LogicalPlan::projection)?, - )) - } - // The schema of the plan should also be updated if the child plan is transformed. - LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }, _) => { - Ok(Transformed::yes( - SubqueryAlias::try_new(input, alias).map(LogicalPlan::subquery_alias)?, - )) - } - LogicalPlan::Distinct(Distinct::On(distinct_on), _) => { - let projected_expr = - expand_exprlist(&distinct_on.input, distinct_on.select_expr)?; - validate_unique_names("Distinct", projected_expr.iter())?; - Ok(Transformed::yes(LogicalPlan::distinct(Distinct::On( - DistinctOn::try_new( - distinct_on.on_expr, - projected_expr, - distinct_on.sort_expr, - distinct_on.input, - )?, - )))) - } - _ => Ok(Transformed::no(plan)), - } -} - fn expand_exprlist(input: &LogicalPlan, expr: Vec) -> Result> { let mut projected_expr = vec![]; let input = find_base_plan(input); diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index afad2fae2ca8..9b7992efb2eb 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -17,11 +17,18 @@ //! [`Analyzer`] and [`AnalyzerRule`] +use enumset::enum_set; +use log::debug; use std::fmt::Debug; use std::sync::Arc; -use log::debug; - +use crate::analyzer::count_wildcard_rule::CountWildcardRule; +use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule; +use crate::analyzer::inline_table_scan::InlineTableScan; +use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction; +use crate::analyzer::subquery::check_subquery_expr; +use crate::analyzer::type_coercion::TypeCoercion; +use crate::utils::log_plan; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; @@ -29,16 +36,9 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; use datafusion_expr::expr_rewriter::FunctionRewrite; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{Expr, LogicalPlan}; -use crate::analyzer::count_wildcard_rule::CountWildcardRule; -use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule; -use crate::analyzer::inline_table_scan::InlineTableScan; -use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction; -use crate::analyzer::subquery::check_subquery_expr; -use crate::analyzer::type_coercion::TypeCoercion; -use crate::utils::log_plan; - use self::function_rewrite::ApplyFunctionRewrites; pub mod count_wildcard_rule; @@ -177,9 +177,25 @@ impl Analyzer { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { plan.apply_with_subqueries(|plan: &LogicalPlan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprExists + | LogicalPlanPattern::ExprInSubquery + | LogicalPlanPattern::ExprScalarSubquery + )) { + return Ok(TreeNodeRecursion::Jump); + } + plan.apply_expressions(|expr| { // recursively look for subqueries expr.apply(|expr| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprExists + | LogicalPlanPattern::ExprInSubquery + | LogicalPlanPattern::ExprScalarSubquery + )) { + return Ok(TreeNodeRecursion::Jump); + } + match expr { Expr::Exists(Exists { subquery, .. }, _) | Expr::InSubquery(InSubquery { subquery, .. }, _) diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 2ff148e55987..3255570bdf8a 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -22,6 +22,7 @@ use recursive::recursive; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::utils::split_conjunction; use datafusion_expr::{Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window}; @@ -255,6 +256,13 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { let mut exprs = vec![]; inner_plan.apply_with_subqueries(|plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanFilter) + { + return Ok(TreeNodeRecursion::Jump); + } + if let LogicalPlan::Filter(Filter { predicate, .. }, _) = plan { let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 4fca0dfd4c4a..8591d811fd9e 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -17,14 +17,16 @@ //! [`EliminateDuplicatedExpr`] Removes redundant expressions -use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{Aggregate, Expr, Sort, SortExpr}; +use enumset::enum_set; use indexmap::IndexSet; use std::hash::{Hash, Hasher}; + /// Optimization rule that eliminate duplicated expr. #[derive(Default, Debug)] pub struct EliminateDuplicatedExpr; @@ -49,10 +51,6 @@ impl Hash for SortExprWrapper { } } impl OptimizerRule for EliminateDuplicatedExpr { - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -62,51 +60,60 @@ impl OptimizerRule for EliminateDuplicatedExpr { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Sort(sort, _) => { - let len = sort.expr.len(); - let unique_exprs: Vec<_> = sort - .expr - .into_iter() - .map(SortExprWrapper) - .collect::>() - .into_iter() - .map(|wrapper| wrapper.0) - .collect(); + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanSort + | LogicalPlanPattern::LogicalPlanAggregate + )) { + return Ok(Transformed::jump(plan)); + } - let transformed = if len != unique_exprs.len() { - Transformed::yes - } else { - Transformed::no - }; + match plan { + LogicalPlan::Sort(sort, _) => { + let len = sort.expr.len(); + let unique_exprs: Vec<_> = sort + .expr + .into_iter() + .map(SortExprWrapper) + .collect::>() + .into_iter() + .map(|wrapper| wrapper.0) + .collect(); - Ok(transformed(LogicalPlan::sort(Sort { - expr: unique_exprs, - input: sort.input, - fetch: sort.fetch, - }))) - } - LogicalPlan::Aggregate(agg, _) => { - let len = agg.group_expr.len(); + let transformed = if len != unique_exprs.len() { + Transformed::yes + } else { + Transformed::no + }; + + Ok(transformed(LogicalPlan::sort(Sort { + expr: unique_exprs, + input: sort.input, + fetch: sort.fetch, + }))) + } + LogicalPlan::Aggregate(agg, _) => { + let len = agg.group_expr.len(); - let unique_exprs: Vec = agg - .group_expr - .into_iter() - .collect::>() - .into_iter() - .collect(); + let unique_exprs: Vec = agg + .group_expr + .into_iter() + .collect::>() + .into_iter() + .collect(); - let transformed = if len != unique_exprs.len() { - Transformed::yes - } else { - Transformed::no - }; + let transformed = if len != unique_exprs.len() { + Transformed::yes + } else { + Transformed::no + }; - Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr) - .map(|f| transformed(LogicalPlan::aggregate(f))) + Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr) + .map(|f| transformed(LogicalPlan::aggregate(f))) + } + _ => Ok(Transformed::no(plan)), } - _ => Ok(Transformed::no(plan)), - } + }) } fn name(&self) -> &str { "eliminate_duplicated_expr" diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 38d4a531444d..ffd983f8e602 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -17,14 +17,14 @@ //! [`EliminateFilter`] replaces `where false` or `where null` with an empty relation. +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{EmptyRelation, Expr, Filter, LogicalPlan}; +use enumset::enum_set; use std::sync::Arc; -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; - /// Optimization rule that eliminate the scalar value (true/false/null) filter /// with an [LogicalPlan::EmptyRelation] /// @@ -45,10 +45,6 @@ impl OptimizerRule for EliminateFilter { "eliminate_filter" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -58,25 +54,33 @@ impl OptimizerRule for EliminateFilter { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Filter( - Filter { - predicate: Expr::Literal(ScalarValue::Boolean(v), _), - input, - .. - }, - _, - ) => match v { - Some(true) => Ok(Transformed::yes(Arc::unwrap_or_clone(input))), - Some(false) | None => Ok(Transformed::yes(LogicalPlan::empty_relation( - EmptyRelation { - produce_one_row: false, - schema: Arc::clone(input.schema()), + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanFilter | LogicalPlanPattern::ExprLiteral + )) { + return Ok(Transformed::jump(plan)); + } + + match plan { + LogicalPlan::Filter( + Filter { + predicate: Expr::Literal(ScalarValue::Boolean(v), _), + input, + .. }, - ))), - }, - _ => Ok(Transformed::no(plan)), - } + _, + ) => match v { + Some(true) => Ok(Transformed::yes(Arc::unwrap_or_clone(input))), + Some(false) | None => Ok(Transformed::yes( + LogicalPlan::empty_relation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(input.schema()), + }), + )), + }, + _ => Ok(Transformed::no(plan)), + } + }) } } diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 6b06e7927266..bc116f5ae4e6 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -16,11 +16,13 @@ // under the License. //! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` clause -use crate::optimizer::ApplyOrder; + use crate::{OptimizerConfig, OptimizerRule}; +use std::cell::Cell; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility}; /// Optimizer rule that removes constant expressions from `GROUP BY` clause @@ -45,50 +47,71 @@ impl OptimizerRule for EliminateGroupByConstant { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Aggregate(aggregate, _) => { - let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate - .group_expr - .iter() - .partition(|expr| is_constant_expression(expr)); - - // If no constant expressions found (nothing to optimize) or - // constant expression is the only expression in aggregate, - // optimization is skipped - if const_group_expr.is_empty() - || (!const_group_expr.is_empty() - && nonconst_group_expr.is_empty() - && aggregate.aggr_expr.is_empty()) + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanAggregate) { - return Ok(Transformed::no(LogicalPlan::aggregate(aggregate))); + skip.set(true); + return Ok(Transformed::jump(plan)); } - let simplified_aggregate = LogicalPlan::aggregate(Aggregate::try_new( - aggregate.input, - nonconst_group_expr.into_iter().cloned().collect(), - aggregate.aggr_expr.clone(), - )?); - - let projection_expr = - aggregate.group_expr.into_iter().chain(aggregate.aggr_expr); - - let projection = LogicalPlanBuilder::from(simplified_aggregate) - .project(projection_expr)? - .build()?; + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } - Ok(Transformed::yes(projection)) - } - _ => Ok(Transformed::no(plan)), - } + match plan { + LogicalPlan::Aggregate(aggregate, _) => { + let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = + aggregate + .group_expr + .iter() + .partition(|expr| is_constant_expression(expr)); + + // If no constant expressions found (nothing to optimize) or + // constant expression is the only expression in aggregate, + // optimization is skipped + if const_group_expr.is_empty() + || (!const_group_expr.is_empty() + && nonconst_group_expr.is_empty() + && aggregate.aggr_expr.is_empty()) + { + return Ok(Transformed::no(LogicalPlan::aggregate( + aggregate, + ))); + } + + let simplified_aggregate = + LogicalPlan::aggregate(Aggregate::try_new( + aggregate.input, + nonconst_group_expr.into_iter().cloned().collect(), + aggregate.aggr_expr.clone(), + )?); + + let projection_expr = + aggregate.group_expr.into_iter().chain(aggregate.aggr_expr); + + let projection = LogicalPlanBuilder::from(simplified_aggregate) + .project(projection_expr)? + .build()?; + + Ok(Transformed::yes(projection)) + } + _ => Ok(Transformed::no(plan)), + } + }, + ) } fn name(&self) -> &str { "eliminate_group_by_constant" } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } } /// Checks if expression is constant, and can be eliminated from group by. diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index e36c20d6c898..ba196d7da628 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -16,15 +16,17 @@ // under the License. //! [`EliminateJoin`] rewrites `INNER JOIN` with `true`/`null` -use crate::optimizer::ApplyOrder; + use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::JoinType::Inner; use datafusion_expr::{ logical_plan::{EmptyRelation, LogicalPlan}, Expr, }; +use enumset::enum_set; /// Eliminates joins when join condition is false. /// Replaces joins when inner join condition is true with a cross join. @@ -42,31 +44,37 @@ impl OptimizerRule for EliminateJoin { "eliminate_join" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn rewrite( &self, plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Join(join, _) - if join.join_type == Inner && join.on.is_empty() => - { - match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(false)), _)) => Ok( - Transformed::yes(LogicalPlan::empty_relation(EmptyRelation { - produce_one_row: false, - schema: join.schema, - })), - ), - _ => Ok(Transformed::no(LogicalPlan::join(join))), + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_all_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanJoin | LogicalPlanPattern::ExprLiteral + )) { + return Ok(Transformed::jump(plan)); + } + + match plan { + LogicalPlan::Join(join, _) + if join.join_type == Inner && join.on.is_empty() => + { + match join.filter { + Some(Expr::Literal(ScalarValue::Boolean(Some(false)), _)) => { + Ok(Transformed::yes(LogicalPlan::empty_relation( + EmptyRelation { + produce_one_row: false, + schema: join.schema, + }, + ))) + } + _ => Ok(Transformed::no(LogicalPlan::join(join))), + } } + _ => Ok(Transformed::no(plan)), } - _ => Ok(Transformed::no(plan)), - } + }) } fn supports_rewrite(&self) -> bool { diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index d47aa3a48ec1..bc5efe327303 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -16,11 +16,13 @@ // under the License. //! [`EliminateLimit`] eliminates `LIMIT` when possible -use crate::optimizer::ApplyOrder; + use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::{EmptyRelation, FetchType, LogicalPlan, SkipType}; +use std::cell::Cell; use std::sync::Arc; /// Optimizer rule to replace `LIMIT 0` or `LIMIT` whose ancestor LIMIT's skip is @@ -45,10 +47,6 @@ impl OptimizerRule for EliminateLimit { "eliminate_limit" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - fn supports_rewrite(&self) -> bool { true } @@ -58,31 +56,53 @@ impl OptimizerRule for EliminateLimit { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result, datafusion_common::DataFusionError> { - match plan { - LogicalPlan::Limit(limit, _) => { - // Only supports rewriting for literal fetch - let FetchType::Literal(fetch) = limit.get_fetch_type()? else { - return Ok(Transformed::no(LogicalPlan::limit(limit))); - }; + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanLimit) + { + skip.set(true); + return Ok(Transformed::jump(plan)); + } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + match plan { + LogicalPlan::Limit(limit, _) => { + // Only supports rewriting for literal fetch + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::limit(limit))); + }; - if let Some(v) = fetch { - if v == 0 { - return Ok(Transformed::yes(LogicalPlan::empty_relation( - EmptyRelation { - produce_one_row: false, - schema: Arc::clone(limit.input.schema()), - }, - ))); + if let Some(v) = fetch { + if v == 0 { + return Ok(Transformed::yes( + LogicalPlan::empty_relation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(limit.input.schema()), + }), + )); + } + } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) { + // If fetch is `None` and skip is 0, then Limit takes no effect and + // we can remove it. Its input also can be Limit, so we should apply again. + return self + .rewrite(Arc::unwrap_or_clone(limit.input), _config); + } + Ok(Transformed::no(LogicalPlan::limit(limit))) } - } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) { - // If fetch is `None` and skip is 0, then Limit takes no effect and - // we can remove it. Its input also can be Limit, so we should apply again. - return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); + _ => Ok(Transformed::no(plan)), } - Ok(Transformed::no(LogicalPlan::limit(limit))) - } - _ => Ok(Transformed::no(plan)), - } + }, + ) } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 4979ddc2f3ac..9103a869ef9d 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -16,13 +16,15 @@ // under the License. //! [`EliminateNestedUnion`]: flattens nested `Union` to a single `Union` -use crate::optimizer::ApplyOrder; + use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{Distinct, LogicalPlan, Union}; use itertools::Itertools; +use std::cell::Cell; use std::sync::Arc; #[derive(Default, Debug)] @@ -41,10 +43,6 @@ impl OptimizerRule for EliminateNestedUnion { "eliminate_nested_union" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - fn supports_rewrite(&self) -> bool { true } @@ -54,43 +52,69 @@ impl OptimizerRule for EliminateNestedUnion { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Union(Union { inputs, schema }, _) => { - let inputs = inputs - .into_iter() - .flat_map(extract_plans_from_union) - .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) - .collect::>>()?; - - Ok(Transformed::yes(LogicalPlan::union(Union { - inputs: inputs.into_iter().map(Arc::new).collect_vec(), - schema, - }))) - } - LogicalPlan::Distinct(Distinct::All(nested_plan), _) => { - match Arc::unwrap_or_clone(nested_plan) { + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanUnion) + { + skip.set(true); + return Ok(Transformed::jump(plan)); + } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + match plan { LogicalPlan::Union(Union { inputs, schema }, _) => { let inputs = inputs .into_iter() - .map(extract_plan_from_distinct) .flat_map(extract_plans_from_union) .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::distinct(Distinct::All( - Arc::new(LogicalPlan::union(Union { - inputs: inputs.into_iter().map(Arc::new).collect_vec(), - schema: Arc::clone(&schema), - })), - )))) + Ok(Transformed::yes(LogicalPlan::union(Union { + inputs: inputs.into_iter().map(Arc::new).collect_vec(), + schema, + }))) } - nested_plan => Ok(Transformed::no(LogicalPlan::distinct( - Distinct::All(Arc::new(nested_plan)), - ))), + LogicalPlan::Distinct(Distinct::All(nested_plan), _) => { + match Arc::unwrap_or_clone(nested_plan) { + LogicalPlan::Union(Union { inputs, schema }, _) => { + let inputs = inputs + .into_iter() + .map(extract_plan_from_distinct) + .flat_map(extract_plans_from_union) + .map(|plan| { + coerce_plan_expr_for_schema(plan, &schema) + }) + .collect::>>()?; + + Ok(Transformed::yes(LogicalPlan::distinct( + Distinct::All(Arc::new(LogicalPlan::union(Union { + inputs: inputs + .into_iter() + .map(Arc::new) + .collect_vec(), + schema: Arc::clone(&schema), + }))), + ))) + } + nested_plan => Ok(Transformed::no(LogicalPlan::distinct( + Distinct::All(Arc::new(nested_plan)), + ))), + } + } + _ => Ok(Transformed::no(plan)), } - } - _ => Ok(Transformed::no(plan)), - } + }, + ) } } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index ac3da4e8f65d..f06a30f6c17f 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -19,11 +19,10 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{tree_node::Transformed, Result}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::{LogicalPlan, Union}; use std::sync::Arc; -use crate::optimizer::ApplyOrder; - #[derive(Default, Debug)] /// An optimization rule that eliminates union with one element. pub struct EliminateOneUnion; @@ -49,16 +48,23 @@ impl OptimizerRule for EliminateOneUnion { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Union(Union { mut inputs, .. }, _) if inputs.len() == 1 => Ok( - Transformed::yes(Arc::unwrap_or_clone(inputs.pop().unwrap())), - ), - _ => Ok(Transformed::no(plan)), - } - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) + plan.transform_down_with_subqueries(|plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanUnion) + { + return Ok(Transformed::jump(plan)); + } + + match plan { + LogicalPlan::Union(Union { mut inputs, .. }, _) if inputs.len() == 1 => { + Ok(Transformed::yes(Arc::unwrap_or_clone( + inputs.pop().unwrap(), + ))) + } + _ => Ok(Transformed::no(plan)), + } + }) } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index bca5d61d4c44..1891cf80a5be 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -21,9 +21,10 @@ use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan}; use datafusion_expr::{Expr, Filter, Operator}; -use crate::optimizer::ApplyOrder; use datafusion_common::tree_node::Transformed; use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; +use enumset::enum_set; use std::sync::Arc; /// @@ -64,10 +65,6 @@ impl OptimizerRule for EliminateOuterJoin { "eliminate_outer_join" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -77,61 +74,70 @@ impl OptimizerRule for EliminateOuterJoin { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Filter(mut filter, _) => { - match Arc::unwrap_or_clone(filter.input) { - LogicalPlan::Join(join, _) => { - let mut non_nullable_cols: Vec = vec![]; + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_all_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanFilter + | LogicalPlanPattern::LogicalPlanJoin + )) { + return Ok(Transformed::jump(plan)); + } - extract_non_nullable_columns( - &filter.predicate, - &mut non_nullable_cols, - join.left.schema(), - join.right.schema(), - true, - ); + match plan { + LogicalPlan::Filter(mut filter, _) => { + match Arc::unwrap_or_clone(filter.input) { + LogicalPlan::Join(join, _) => { + let mut non_nullable_cols: Vec = vec![]; - let new_join_type = if join.join_type.is_outer() { - let mut left_non_nullable = false; - let mut right_non_nullable = false; - for col in non_nullable_cols.iter() { - if join.left.schema().has_column(col) { - left_non_nullable = true; - } - if join.right.schema().has_column(col) { - right_non_nullable = true; + extract_non_nullable_columns( + &filter.predicate, + &mut non_nullable_cols, + join.left.schema(), + join.right.schema(), + true, + ); + + let new_join_type = if join.join_type.is_outer() { + let mut left_non_nullable = false; + let mut right_non_nullable = false; + for col in non_nullable_cols.iter() { + if join.left.schema().has_column(col) { + left_non_nullable = true; + } + if join.right.schema().has_column(col) { + right_non_nullable = true; + } } - } - eliminate_outer( - join.join_type, - left_non_nullable, - right_non_nullable, - ) - } else { - join.join_type - }; + eliminate_outer( + join.join_type, + left_non_nullable, + right_non_nullable, + ) + } else { + join.join_type + }; - let new_join = Arc::new(LogicalPlan::join(Join { - left: join.left, - right: join.right, - join_type: new_join_type, - join_constraint: join.join_constraint, - on: join.on.clone(), - filter: join.filter.clone(), - schema: Arc::clone(&join.schema), - null_equals_null: join.null_equals_null, - })); - Filter::try_new(filter.predicate, new_join) - .map(|f| Transformed::yes(LogicalPlan::filter(f))) - } - filter_input => { - filter.input = Arc::new(filter_input); - Ok(Transformed::no(LogicalPlan::filter(filter))) + let new_join = Arc::new(LogicalPlan::join(Join { + left: join.left, + right: join.right, + join_type: new_join_type, + join_constraint: join.join_constraint, + on: join.on.clone(), + filter: join.filter.clone(), + schema: Arc::clone(&join.schema), + null_equals_null: join.null_equals_null, + })); + Filter::try_new(filter.predicate, new_join) + .map(|f| Transformed::yes(LogicalPlan::filter(f))) + } + filter_input => { + filter.input = Arc::new(filter_input); + Ok(Transformed::no(LogicalPlan::filter(filter))) + } } } + _ => Ok(Transformed::no(plan)), } - _ => Ok(Transformed::no(plan)), - } + }) } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 16c3355c3b8f..ce82ee4d98ff 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -16,14 +16,17 @@ // under the License. //! [`ExtractEquijoinPredicate`] identifies equality join (equijoin) predicates -use crate::optimizer::ApplyOrder; + use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::DFSchema; use datafusion_common::Result; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::utils::split_conjunction_owned; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; +use std::cell::Cell; + // equijoin predicate type EquijoinPredicate = (Expr, Expr); @@ -57,61 +60,82 @@ impl OptimizerRule for ExtractEquijoinPredicate { "extract_equijoin_predicate" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - fn rewrite( &self, plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Join( - Join { - left, - right, - mut on, - filter: Some(expr), - join_type, - join_constraint, - schema, - null_equals_null, - }, - _, - ) => { - let left_schema = left.schema(); - let right_schema = right.schema(); - let (equijoin_predicates, non_equijoin_expr) = - split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?; - - if !equijoin_predicates.is_empty() { - on.extend(equijoin_predicates); - Ok(Transformed::yes(LogicalPlan::join(Join { - left, - right, - on, - filter: non_equijoin_expr, - join_type, - join_constraint, - schema, - null_equals_null, - }))) - } else { - Ok(Transformed::no(LogicalPlan::join(Join { - left, - right, - on, - filter: non_equijoin_expr, - join_type, - join_constraint, - schema, - null_equals_null, - }))) + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanJoin) + { + skip.set(true); + return Ok(Transformed::jump(plan)); } - } - _ => Ok(Transformed::no(plan)), - } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + match plan { + LogicalPlan::Join( + Join { + left, + right, + mut on, + filter: Some(expr), + join_type, + join_constraint, + schema, + null_equals_null, + }, + _, + ) => { + let left_schema = left.schema(); + let right_schema = right.schema(); + let (equijoin_predicates, non_equijoin_expr) = + split_eq_and_noneq_join_predicate( + expr, + left_schema, + right_schema, + )?; + + if !equijoin_predicates.is_empty() { + on.extend(equijoin_predicates); + Ok(Transformed::yes(LogicalPlan::join(Join { + left, + right, + on, + filter: non_equijoin_expr, + join_type, + join_constraint, + schema, + null_equals_null, + }))) + } else { + Ok(Transformed::no(LogicalPlan::join(Join { + left, + right, + on, + filter: non_equijoin_expr, + join_type, + join_constraint, + schema, + null_equals_null, + }))) + } + } + _ => Ok(Transformed::no(plan)), + } + }, + ) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 3f190ff32667..6b935e16ca8c 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -17,13 +17,14 @@ //! [`FilterNullJoinKeys`] adds filters to join inputs when input isn't nullable -use crate::optimizer::ApplyOrder; use crate::push_down_filter::on_lr_is_preserved; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::utils::conjunction; use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan}; +use std::cell::Cell; use std::sync::Arc; /// The FilterNullJoinKeys rule will identify joins with equi-join conditions @@ -37,10 +38,6 @@ impl OptimizerRule for FilterNullJoinKeys { true } - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - fn rewrite( &self, plan: LogicalPlan, @@ -49,45 +46,67 @@ impl OptimizerRule for FilterNullJoinKeys { if !config.options().optimizer.filter_null_join_keys { return Ok(Transformed::no(plan)); } - match plan { - LogicalPlan::Join(mut join, _) - if !join.on.is_empty() && !join.null_equals_null => - { - let (left_preserved, right_preserved) = - on_lr_is_preserved(join.join_type); - let left_schema = join.left.schema(); - let right_schema = join.right.schema(); + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanJoin) + { + skip.set(true); + return Ok(Transformed::jump(plan)); + } - let mut left_filters = vec![]; - let mut right_filters = vec![]; + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } - for (l, r) in &join.on { - if left_preserved && l.nullable(left_schema)? { - left_filters.push(l.clone()); - } + match plan { + LogicalPlan::Join(mut join, _) + if !join.on.is_empty() && !join.null_equals_null => + { + let (left_preserved, right_preserved) = + on_lr_is_preserved(join.join_type); - if right_preserved && r.nullable(right_schema)? { - right_filters.push(r.clone()); - } - } + let left_schema = join.left.schema(); + let right_schema = join.right.schema(); - if !left_filters.is_empty() { - let predicate = create_not_null_predicate(left_filters); - join.left = Arc::new(LogicalPlan::filter(Filter::try_new( - predicate, join.left, - )?)); - } - if !right_filters.is_empty() { - let predicate = create_not_null_predicate(right_filters); - join.right = Arc::new(LogicalPlan::filter(Filter::try_new( - predicate, join.right, - )?)); + let mut left_filters = vec![]; + let mut right_filters = vec![]; + + for (l, r) in &join.on { + if left_preserved && l.nullable(left_schema)? { + left_filters.push(l.clone()); + } + + if right_preserved && r.nullable(right_schema)? { + right_filters.push(r.clone()); + } + } + + if !left_filters.is_empty() { + let predicate = create_not_null_predicate(left_filters); + join.left = Arc::new(LogicalPlan::filter(Filter::try_new( + predicate, join.left, + )?)); + } + if !right_filters.is_empty() { + let predicate = create_not_null_predicate(right_filters); + join.right = Arc::new(LogicalPlan::filter(Filter::try_new( + predicate, join.right, + )?)); + } + Ok(Transformed::yes(LogicalPlan::join(join))) + } + _ => Ok(Transformed::no(plan)), } - Ok(Transformed::yes(LogicalPlan::join(join))) - } - _ => Ok(Transformed::no(plan)), - } + }, + ) } fn name(&self) -> &str { "filter_null_join_keys" diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 8f38970d1dc4..e57da0d362f5 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -17,16 +17,16 @@ //! [`PropagateEmptyRelation`] eliminates nodes fed by `EmptyRelation` -use std::sync::Arc; - +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::JoinType; use datafusion_common::{plan_err, Result}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{EmptyRelation, Projection, Union}; - -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; +use enumset::enum_set; +use std::cell::Cell; +use std::sync::Arc; /// Optimization rule that bottom-up to eliminate plan by propagating empty_relation. #[derive(Default, Debug)] @@ -44,10 +44,6 @@ impl OptimizerRule for PropagateEmptyRelation { "propagate_empty_relation" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - fn supports_rewrite(&self) -> bool { true } @@ -57,140 +53,183 @@ impl OptimizerRule for PropagateEmptyRelation { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::EmptyRelation(_, _) => Ok(Transformed::no(plan)), - LogicalPlan::Projection(_, _) - | LogicalPlan::Filter(_, _) - | LogicalPlan::Window(_, _) - | LogicalPlan::Sort(_, _) - | LogicalPlan::SubqueryAlias(_, _) - | LogicalPlan::Repartition(_, _) - | LogicalPlan::Limit(_, _) => { - let empty = empty_child(&plan)?; - if let Some(empty_plan) = empty { - return Ok(Transformed::yes(empty_plan)); + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !(plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanProjection + | LogicalPlanPattern::LogicalPlanFilter + | LogicalPlanPattern::LogicalPlanWindow + | LogicalPlanPattern::LogicalPlanSort + | LogicalPlanPattern::LogicalPlanSubqueryAlias + | LogicalPlanPattern::LogicalPlanRepartition + | LogicalPlanPattern::LogicalPlanLimit + | LogicalPlanPattern::LogicalPlanJoin + | LogicalPlanPattern::LogicalPlanAggregate + | LogicalPlanPattern::LogicalPlanUnion + )) && plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanEmptyRelation)) + { + skip.set(true); + return Ok(Transformed::jump(plan)); } + Ok(Transformed::no(plan)) - } - LogicalPlan::Join(ref join, _) => { - // TODO: For Join, more join type need to be careful: - // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side - // columns + right side columns replaced with null values. - // For RightOut/Full Join, if the left side is empty, the Join can be eliminated with a Projection with right side - // columns + left side columns replaced with null values. - let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?; - - match join.join_type { - // For Full Join, only both sides are empty, the Join result is empty. - JoinType::Full if left_empty && right_empty => Ok(Transformed::yes( - LogicalPlan::empty_relation(EmptyRelation { - produce_one_row: false, - schema: Arc::clone(&join.schema), - }), - )), - JoinType::Inner if left_empty || right_empty => Ok(Transformed::yes( - LogicalPlan::empty_relation(EmptyRelation { - produce_one_row: false, - schema: Arc::clone(&join.schema), - }), - )), - JoinType::Left if left_empty => Ok(Transformed::yes( - LogicalPlan::empty_relation(EmptyRelation { - produce_one_row: false, - schema: Arc::clone(&join.schema), - }), - )), - JoinType::Right if right_empty => Ok(Transformed::yes( - LogicalPlan::empty_relation(EmptyRelation { - produce_one_row: false, - schema: Arc::clone(&join.schema), - }), - )), - JoinType::LeftSemi if left_empty || right_empty => Ok( - Transformed::yes(LogicalPlan::empty_relation(EmptyRelation { - produce_one_row: false, - schema: Arc::clone(&join.schema), - })), - ), - JoinType::RightSemi if left_empty || right_empty => Ok( - Transformed::yes(LogicalPlan::empty_relation(EmptyRelation { - produce_one_row: false, - schema: Arc::clone(&join.schema), - })), - ), - JoinType::LeftAnti if left_empty => Ok(Transformed::yes( - LogicalPlan::empty_relation(EmptyRelation { - produce_one_row: false, - schema: Arc::clone(&join.schema), - }), - )), - JoinType::LeftAnti if right_empty => { - Ok(Transformed::yes((*join.left).clone())) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + match plan { + LogicalPlan::EmptyRelation(_, _) => Ok(Transformed::no(plan)), + LogicalPlan::Projection(_, _) + | LogicalPlan::Filter(_, _) + | LogicalPlan::Window(_, _) + | LogicalPlan::Sort(_, _) + | LogicalPlan::SubqueryAlias(_, _) + | LogicalPlan::Repartition(_, _) + | LogicalPlan::Limit(_, _) => { + let empty = empty_child(&plan)?; + if let Some(empty_plan) = empty { + return Ok(Transformed::yes(empty_plan)); + } + Ok(Transformed::no(plan)) } - JoinType::RightAnti if left_empty => { - Ok(Transformed::yes((*join.right).clone())) + LogicalPlan::Join(ref join, _) => { + // TODO: For Join, more join type need to be careful: + // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side + // columns + right side columns replaced with null values. + // For RightOut/Full Join, if the left side is empty, the Join can be eliminated with a Projection with right side + // columns + left side columns replaced with null values. + let (left_empty, right_empty) = + binary_plan_children_is_empty(&plan)?; + + match join.join_type { + // For Full Join, only both sides are empty, the Join result is empty. + JoinType::Full if left_empty && right_empty => { + Ok(Transformed::yes(LogicalPlan::empty_relation( + EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }, + ))) + } + JoinType::Inner if left_empty || right_empty => { + Ok(Transformed::yes(LogicalPlan::empty_relation( + EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }, + ))) + } + JoinType::Left if left_empty => Ok(Transformed::yes( + LogicalPlan::empty_relation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }), + )), + JoinType::Right if right_empty => Ok(Transformed::yes( + LogicalPlan::empty_relation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }), + )), + JoinType::LeftSemi if left_empty || right_empty => { + Ok(Transformed::yes(LogicalPlan::empty_relation( + EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }, + ))) + } + JoinType::RightSemi if left_empty || right_empty => { + Ok(Transformed::yes(LogicalPlan::empty_relation( + EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }, + ))) + } + JoinType::LeftAnti if left_empty => Ok(Transformed::yes( + LogicalPlan::empty_relation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }), + )), + JoinType::LeftAnti if right_empty => { + Ok(Transformed::yes((*join.left).clone())) + } + JoinType::RightAnti if left_empty => { + Ok(Transformed::yes((*join.right).clone())) + } + JoinType::RightAnti if right_empty => Ok(Transformed::yes( + LogicalPlan::empty_relation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }), + )), + _ => Ok(Transformed::no(plan)), + } } - JoinType::RightAnti if right_empty => Ok(Transformed::yes( - LogicalPlan::empty_relation(EmptyRelation { - produce_one_row: false, - schema: Arc::clone(&join.schema), - }), - )), - _ => Ok(Transformed::no(plan)), - } - } - LogicalPlan::Aggregate(ref agg, _) => { - if !agg.group_expr.is_empty() { - if let Some(empty_plan) = empty_child(&plan)? { - return Ok(Transformed::yes(empty_plan)); + LogicalPlan::Aggregate(ref agg, _) => { + if !agg.group_expr.is_empty() { + if let Some(empty_plan) = empty_child(&plan)? { + return Ok(Transformed::yes(empty_plan)); + } + } + Ok(Transformed::no(LogicalPlan::aggregate(agg.clone()))) } - } - Ok(Transformed::no(LogicalPlan::aggregate(agg.clone()))) - } - LogicalPlan::Union(ref union, _) => { - let new_inputs = union - .inputs - .iter() - .filter(|input| match &***input { - LogicalPlan::EmptyRelation(empty, _) => empty.produce_one_row, - _ => true, - }) - .cloned() - .collect::>(); - - if new_inputs.len() == union.inputs.len() { - Ok(Transformed::no(plan)) - } else if new_inputs.is_empty() { - Ok(Transformed::yes(LogicalPlan::empty_relation( - EmptyRelation { - produce_one_row: false, - schema: Arc::clone(plan.schema()), - }, - ))) - } else if new_inputs.len() == 1 { - let mut new_inputs = new_inputs; - let input_plan = new_inputs.pop().unwrap(); // length checked - let child = Arc::unwrap_or_clone(input_plan); - if child.schema().eq(plan.schema()) { - Ok(Transformed::yes(child)) - } else { - Ok(Transformed::yes(LogicalPlan::projection( - Projection::new_from_schema( - Arc::new(child), - Arc::clone(plan.schema()), - ), - ))) + LogicalPlan::Union(ref union, _) => { + let new_inputs = union + .inputs + .iter() + .filter(|input| match &***input { + LogicalPlan::EmptyRelation(empty, _) => { + empty.produce_one_row + } + _ => true, + }) + .cloned() + .collect::>(); + + if new_inputs.len() == union.inputs.len() { + Ok(Transformed::no(plan)) + } else if new_inputs.is_empty() { + Ok(Transformed::yes(LogicalPlan::empty_relation( + EmptyRelation { + produce_one_row: false, + schema: Arc::clone(plan.schema()), + }, + ))) + } else if new_inputs.len() == 1 { + let mut new_inputs = new_inputs; + let input_plan = new_inputs.pop().unwrap(); // length checked + let child = Arc::unwrap_or_clone(input_plan); + if child.schema().eq(plan.schema()) { + Ok(Transformed::yes(child)) + } else { + Ok(Transformed::yes(LogicalPlan::projection( + Projection::new_from_schema( + Arc::new(child), + Arc::clone(plan.schema()), + ), + ))) + } + } else { + Ok(Transformed::yes(LogicalPlan::union(Union { + inputs: new_inputs, + schema: Arc::clone(&union.schema), + }))) + } } - } else { - Ok(Transformed::yes(LogicalPlan::union(Union { - inputs: new_inputs, - schema: Arc::clone(&union.schema), - }))) - } - } - _ => Ok(Transformed::no(plan)), - } + _ => Ok(Transformed::no(plan)), + } + }, + ) } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7ccc4c9e7857..e592da73e5d7 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -17,12 +17,14 @@ //! [`PushDownFilter`] applies filters as early as possible +use enumset::enum_set; use indexmap::IndexSet; +use itertools::Itertools; use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use itertools::Itertools; - +use crate::utils::{has_all_column_refs, is_restrict_null_predicate}; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -30,6 +32,7 @@ use datafusion_common::{ internal_err, plan_err, qualified_name, Column, DFSchema, Result, }; use datafusion_expr::expr_rewriter::replace_col; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union}; use datafusion_expr::utils::{ conjunction, expr_to_columns, split_conjunction, split_conjunction_owned, @@ -38,10 +41,6 @@ use datafusion_expr::{ and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, }; -use crate::optimizer::ApplyOrder; -use crate::utils::{has_all_column_refs, is_restrict_null_predicate}; -use crate::{OptimizerConfig, OptimizerRule}; - /// Optimizer rule for pushing (moving) filter expressions down in a plan so /// they are applied as early as possible. /// @@ -761,10 +760,6 @@ impl OptimizerRule for PushDownFilter { "push_down_filter" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -774,356 +769,365 @@ impl OptimizerRule for PushDownFilter { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - if let LogicalPlan::Join(join, _) = plan { - return push_down_join(join, None); - }; - - let plan_schema = Arc::clone(plan.schema()); - - let LogicalPlan::Filter(mut filter, _) = plan else { - return Ok(Transformed::no(plan)); - }; - - match Arc::unwrap_or_clone(filter.input) { - LogicalPlan::Filter(child_filter, _) => { - let parents_predicates = split_conjunction_owned(filter.predicate); - - // remove duplicated filters - let child_predicates = split_conjunction_owned(child_filter.predicate); - let new_predicates = parents_predicates - .into_iter() - .chain(child_predicates) - // use IndexSet to remove dupes while preserving predicate order - .collect::>() - .into_iter() - .collect::>(); - - let Some(new_predicate) = conjunction(new_predicates) else { - return plan_err!("at least one expression exists"); - }; - let new_filter = LogicalPlan::filter(Filter::try_new( - new_predicate, - child_filter.input, - )?); - self.rewrite(new_filter, _config) - } - LogicalPlan::Repartition(repartition, _) => { - let new_filter = - Filter::try_new(filter.predicate, Arc::clone(&repartition.input)) - .map(LogicalPlan::filter)?; - insert_below(LogicalPlan::repartition(repartition), new_filter) - } - LogicalPlan::Distinct(distinct, _) => { - let new_filter = - Filter::try_new(filter.predicate, Arc::clone(distinct.input())) - .map(LogicalPlan::filter)?; - insert_below(LogicalPlan::distinct(distinct), new_filter) - } - LogicalPlan::Sort(sort, _) => { - let new_filter = - Filter::try_new(filter.predicate, Arc::clone(&sort.input)) - .map(LogicalPlan::filter)?; - insert_below(LogicalPlan::sort(sort), new_filter) + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanJoin + | LogicalPlanPattern::LogicalPlanFilter + )) { + return Ok(Transformed::jump(plan)); } - LogicalPlan::SubqueryAlias(subquery_alias, _) => { - let mut replace_map = HashMap::new(); - for (i, (qualifier, field)) in - subquery_alias.input.schema().iter().enumerate() - { - let (sub_qualifier, sub_field) = - subquery_alias.schema.qualified_field(i); - replace_map.insert( - qualified_name(sub_qualifier, sub_field.name()), - Expr::column(Column::new(qualifier.cloned(), field.name())), - ); - } - let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; - let new_filter = LogicalPlan::filter(Filter::try_new( - new_predicate, - Arc::clone(&subquery_alias.input), - )?); - insert_below(LogicalPlan::subquery_alias(subquery_alias), new_filter) - } - LogicalPlan::Projection(projection, _) => { - let predicates = split_conjunction_owned(filter.predicate.clone()); - let (new_projection, keep_predicate) = - rewrite_projection(predicates, projection)?; - if new_projection.transformed { - match keep_predicate { - None => Ok(new_projection), - Some(keep_predicate) => new_projection.map_data(|child_plan| { - Filter::try_new(keep_predicate, Arc::new(child_plan)) - .map(LogicalPlan::filter) - }), - } - } else { - filter.input = Arc::new(new_projection.data); - Ok(Transformed::no(LogicalPlan::filter(filter))) + if let LogicalPlan::Join(join, _) = plan { + return push_down_join(join, None); + }; + + let plan_schema = Arc::clone(plan.schema()); + + let LogicalPlan::Filter(mut filter, _) = plan else { + return Ok(Transformed::no(plan)); + }; + + match Arc::unwrap_or_clone(filter.input) { + LogicalPlan::Filter(child_filter, _) => { + let parents_predicates = split_conjunction_owned(filter.predicate); + + // remove duplicated filters + let child_predicates = split_conjunction_owned(child_filter.predicate); + let new_predicates = parents_predicates + .into_iter() + .chain(child_predicates) + // use IndexSet to remove dupes while preserving predicate order + .collect::>() + .into_iter() + .collect::>(); + + let Some(new_predicate) = conjunction(new_predicates) else { + return plan_err!("at least one expression exists"); + }; + let new_filter = LogicalPlan::filter(Filter::try_new( + new_predicate, + child_filter.input, + )?); + self.rewrite(new_filter, _config) } - } - LogicalPlan::Unnest(mut unnest, _) => { - let predicates = split_conjunction_owned(filter.predicate.clone()); - let mut non_unnest_predicates = vec![]; - let mut unnest_predicates = vec![]; - for predicate in predicates { - // collect all the Expr::Column in predicate recursively - let mut accum: HashSet = HashSet::new(); - expr_to_columns(&predicate, &mut accum)?; - - if unnest.list_type_columns.iter().any(|(_, unnest_list)| { - accum.contains(&unnest_list.output_column) - }) { - unnest_predicates.push(predicate); - } else { - non_unnest_predicates.push(predicate); - } + LogicalPlan::Repartition(repartition, _) => { + let new_filter = + Filter::try_new(filter.predicate, Arc::clone(&repartition.input)) + .map(LogicalPlan::filter)?; + insert_below(LogicalPlan::repartition(repartition), new_filter) } - - // Unnest predicates should not be pushed down. - // If no non-unnest predicates exist, early return - if non_unnest_predicates.is_empty() { - filter.input = Arc::new(LogicalPlan::unnest(unnest)); - return Ok(Transformed::no(LogicalPlan::filter(filter))); + LogicalPlan::Distinct(distinct, _) => { + let new_filter = + Filter::try_new(filter.predicate, Arc::clone(distinct.input())) + .map(LogicalPlan::filter)?; + insert_below(LogicalPlan::distinct(distinct), new_filter) } - - // Push down non-unnest filter predicate - // Unnest - // Unnest Input (Projection) - // -> rewritten to - // Unnest - // Filter - // Unnest Input (Projection) - - let unnest_input = std::mem::take(&mut unnest.input); - - let filter_with_unnest_input = LogicalPlan::filter(Filter::try_new( - conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty. - unnest_input, - )?); - - // Directly assign new filter plan as the new unnest's input. - // The new filter plan will go through another rewrite pass since the rule itself - // is applied recursively to all the child from top to down - let unnest_plan = - insert_below(LogicalPlan::unnest(unnest), filter_with_unnest_input)?; - - match conjunction(unnest_predicates) { - None => Ok(unnest_plan), - Some(predicate) => Ok(Transformed::yes(LogicalPlan::filter( - Filter::try_new(predicate, Arc::new(unnest_plan.data))?, - ))), + LogicalPlan::Sort(sort, _) => { + let new_filter = + Filter::try_new(filter.predicate, Arc::clone(&sort.input)) + .map(LogicalPlan::filter)?; + insert_below(LogicalPlan::sort(sort), new_filter) } - } - LogicalPlan::Union(ref union, _) => { - let mut inputs = Vec::with_capacity(union.inputs.len()); - for input in &union.inputs { + LogicalPlan::SubqueryAlias(subquery_alias, _) => { let mut replace_map = HashMap::new(); - for (i, (qualifier, field)) in input.schema().iter().enumerate() { - let (union_qualifier, union_field) = - union.schema.qualified_field(i); + for (i, (qualifier, field)) in + subquery_alias.input.schema().iter().enumerate() + { + let (sub_qualifier, sub_field) = + subquery_alias.schema.qualified_field(i); replace_map.insert( - qualified_name(union_qualifier, union_field.name()), + qualified_name(sub_qualifier, sub_field.name()), Expr::column(Column::new(qualifier.cloned(), field.name())), ); } + let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; - let push_predicate = - replace_cols_by_name(filter.predicate.clone(), &replace_map)?; - inputs.push(Arc::new(LogicalPlan::filter(Filter::try_new( - push_predicate, - Arc::clone(input), - )?))) + let new_filter = LogicalPlan::filter(Filter::try_new( + new_predicate, + Arc::clone(&subquery_alias.input), + )?); + insert_below(LogicalPlan::subquery_alias(subquery_alias), new_filter) } - Ok(Transformed::yes(LogicalPlan::union(Union { - inputs, - schema: Arc::clone(&plan_schema), - }))) - } - LogicalPlan::Aggregate(agg, _) => { - // We can push down Predicate which in groupby_expr. - let group_expr_columns = agg - .group_expr - .iter() - .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string()))) - .collect::>>()?; - - let predicates = split_conjunction_owned(filter.predicate); - - let mut keep_predicates = vec![]; - let mut push_predicates = vec![]; - for expr in predicates { - let cols = expr.column_refs(); - if cols.iter().all(|c| group_expr_columns.contains(c)) { - push_predicates.push(expr); + LogicalPlan::Projection(projection, _) => { + let predicates = split_conjunction_owned(filter.predicate.clone()); + let (new_projection, keep_predicate) = + rewrite_projection(predicates, projection)?; + if new_projection.transformed { + match keep_predicate { + None => Ok(new_projection), + Some(keep_predicate) => new_projection.map_data(|child_plan| { + Filter::try_new(keep_predicate, Arc::new(child_plan)) + .map(LogicalPlan::filter) + }), + } } else { - keep_predicates.push(expr); + filter.input = Arc::new(new_projection.data); + Ok(Transformed::no(LogicalPlan::filter(filter))) } } + LogicalPlan::Unnest(mut unnest, _) => { + let predicates = split_conjunction_owned(filter.predicate.clone()); + let mut non_unnest_predicates = vec![]; + let mut unnest_predicates = vec![]; + for predicate in predicates { + // collect all the Expr::Column in predicate recursively + let mut accum: HashSet = HashSet::new(); + expr_to_columns(&predicate, &mut accum)?; + + if unnest.list_type_columns.iter().any(|(_, unnest_list)| { + accum.contains(&unnest_list.output_column) + }) { + unnest_predicates.push(predicate); + } else { + non_unnest_predicates.push(predicate); + } + } - // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)] - // After push, we need to replace `a+b` with Column(a)+Column(b) - // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))} - let mut replace_map = HashMap::new(); - for expr in &agg.group_expr { - replace_map.insert(expr.schema_name().to_string(), expr.clone()); + // Unnest predicates should not be pushed down. + // If no non-unnest predicates exist, early return + if non_unnest_predicates.is_empty() { + filter.input = Arc::new(LogicalPlan::unnest(unnest)); + return Ok(Transformed::no(LogicalPlan::filter(filter))); + } + + // Push down non-unnest filter predicate + // Unnest + // Unnest Input (Projection) + // -> rewritten to + // Unnest + // Filter + // Unnest Input (Projection) + + let unnest_input = std::mem::take(&mut unnest.input); + + let filter_with_unnest_input = LogicalPlan::filter(Filter::try_new( + conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty. + unnest_input, + )?); + + // Directly assign new filter plan as the new unnest's input. + // The new filter plan will go through another rewrite pass since the rule itself + // is applied recursively to all the child from top to down + let unnest_plan = + insert_below(LogicalPlan::unnest(unnest), filter_with_unnest_input)?; + + match conjunction(unnest_predicates) { + None => Ok(unnest_plan), + Some(predicate) => Ok(Transformed::yes(LogicalPlan::filter( + Filter::try_new(predicate, Arc::new(unnest_plan.data))?, + ))), + } } - let replaced_push_predicates = push_predicates - .into_iter() - .map(|expr| replace_cols_by_name(expr, &replace_map)) - .collect::>>()?; - - let agg_input = Arc::clone(&agg.input); - Transformed::yes(LogicalPlan::aggregate(agg)) - .transform_data(|new_plan| { - // If we have a filter to push, we push it down to the input of the aggregate - if let Some(predicate) = conjunction(replaced_push_predicates) { - let new_filter = make_filter(predicate, agg_input)?; - insert_below(new_plan, new_filter) - } else { - Ok(Transformed::no(new_plan)) + LogicalPlan::Union(ref union, _) => { + let mut inputs = Vec::with_capacity(union.inputs.len()); + for input in &union.inputs { + let mut replace_map = HashMap::new(); + for (i, (qualifier, field)) in input.schema().iter().enumerate() { + let (union_qualifier, union_field) = + union.schema.qualified_field(i); + replace_map.insert( + qualified_name(union_qualifier, union_field.name()), + Expr::column(Column::new(qualifier.cloned(), field.name())), + ); } - })? - .map_data(|child_plan| { - // if there are any remaining predicates we can't push, add them - // back as a filter - if let Some(predicate) = conjunction(keep_predicates) { - make_filter(predicate, Arc::new(child_plan)) + + let push_predicate = + replace_cols_by_name(filter.predicate.clone(), &replace_map)?; + inputs.push(Arc::new(LogicalPlan::filter(Filter::try_new( + push_predicate, + Arc::clone(input), + )?))) + } + Ok(Transformed::yes(LogicalPlan::union(Union { + inputs, + schema: Arc::clone(&plan_schema), + }))) + } + LogicalPlan::Aggregate(agg, _) => { + // We can push down Predicate which in groupby_expr. + let group_expr_columns = agg + .group_expr + .iter() + .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string()))) + .collect::>>()?; + + let predicates = split_conjunction_owned(filter.predicate); + + let mut keep_predicates = vec![]; + let mut push_predicates = vec![]; + for expr in predicates { + let cols = expr.column_refs(); + if cols.iter().all(|c| group_expr_columns.contains(c)) { + push_predicates.push(expr); } else { - Ok(child_plan) + keep_predicates.push(expr); } - }) - } - LogicalPlan::Join(join, _) => push_down_join(join, Some(&filter.predicate)), - LogicalPlan::TableScan(scan, _) => { - let filter_predicates = split_conjunction(&filter.predicate); + } - let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = - filter_predicates + // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)] + // After push, we need to replace `a+b` with Column(a)+Column(b) + // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))} + let mut replace_map = HashMap::new(); + for expr in &agg.group_expr { + replace_map.insert(expr.schema_name().to_string(), expr.clone()); + } + let replaced_push_predicates = push_predicates .into_iter() - .partition(|pred| pred.is_volatile()); - - // Check which non-volatile filters are supported by source - let supported_filters = scan - .source - .supports_filters_pushdown(non_volatile_filters.as_slice())?; - if non_volatile_filters.len() != supported_filters.len() { - return internal_err!( - "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}", - supported_filters.len(), - non_volatile_filters.len()); + .map(|expr| replace_cols_by_name(expr, &replace_map)) + .collect::>>()?; + + let agg_input = Arc::clone(&agg.input); + Transformed::yes(LogicalPlan::aggregate(agg)) + .transform_data(|new_plan| { + // If we have a filter to push, we push it down to the input of the aggregate + if let Some(predicate) = conjunction(replaced_push_predicates) { + let new_filter = make_filter(predicate, agg_input)?; + insert_below(new_plan, new_filter) + } else { + Ok(Transformed::no(new_plan)) + } + })? + .map_data(|child_plan| { + // if there are any remaining predicates we can't push, add them + // back as a filter + if let Some(predicate) = conjunction(keep_predicates) { + make_filter(predicate, Arc::new(child_plan)) + } else { + Ok(child_plan) + } + }) } - - // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type - let zip = non_volatile_filters.into_iter().zip(supported_filters); - - let new_scan_filters = zip - .clone() - .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported) - .map(|(pred, _)| pred); - - // Add new scan filters - let new_scan_filters: Vec = scan - .filters - .iter() - .chain(new_scan_filters) - .unique() - .cloned() - .collect(); - - // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters - let new_predicate: Vec = zip - .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact) - .map(|(pred, _)| pred) - .chain(volatile_filters) - .cloned() - .collect(); - - let new_scan = LogicalPlan::table_scan(TableScan { - filters: new_scan_filters, - ..scan - }); - - Transformed::yes(new_scan).transform_data(|new_scan| { - if let Some(predicate) = conjunction(new_predicate) { - make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes) - } else { - Ok(Transformed::no(new_scan)) + LogicalPlan::Join(join, _) => push_down_join(join, Some(&filter.predicate)), + LogicalPlan::TableScan(scan, _) => { + let filter_predicates = split_conjunction(&filter.predicate); + + let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = + filter_predicates + .into_iter() + .partition(|pred| pred.is_volatile()); + + // Check which non-volatile filters are supported by source + let supported_filters = scan + .source + .supports_filters_pushdown(non_volatile_filters.as_slice())?; + if non_volatile_filters.len() != supported_filters.len() { + return internal_err!( + "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}", + supported_filters.len(), + non_volatile_filters.len()); } - }) - } - LogicalPlan::Extension(extension_plan, _) => { - let prevent_cols = - extension_plan.node.prevent_predicate_push_down_columns(); - // determine if we can push any predicates down past the extension node - - // each element is true for push, false to keep - let predicate_push_or_keep = split_conjunction(&filter.predicate) - .iter() - .map(|expr| { - let cols = expr.column_refs(); - if cols.iter().any(|c| prevent_cols.contains(&c.name)) { - Ok(false) // No push (keep) + // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type + let zip = non_volatile_filters.into_iter().zip(supported_filters); + + let new_scan_filters = zip + .clone() + .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported) + .map(|(pred, _)| pred); + + // Add new scan filters + let new_scan_filters: Vec = scan + .filters + .iter() + .chain(new_scan_filters) + .unique() + .cloned() + .collect(); + + // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters + let new_predicate: Vec = zip + .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact) + .map(|(pred, _)| pred) + .chain(volatile_filters) + .cloned() + .collect(); + + let new_scan = LogicalPlan::table_scan(TableScan { + filters: new_scan_filters, + ..scan + }); + + Transformed::yes(new_scan).transform_data(|new_scan| { + if let Some(predicate) = conjunction(new_predicate) { + make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes) } else { - Ok(true) // push + Ok(Transformed::no(new_scan)) } }) - .collect::>>()?; - - // all predicates are kept, no changes needed - if predicate_push_or_keep.iter().all(|&x| !x) { - filter.input = Arc::new(LogicalPlan::extension(extension_plan)); - return Ok(Transformed::no(LogicalPlan::filter(filter))); } + LogicalPlan::Extension(extension_plan, _) => { + let prevent_cols = + extension_plan.node.prevent_predicate_push_down_columns(); + + // determine if we can push any predicates down past the extension node + + // each element is true for push, false to keep + let predicate_push_or_keep = split_conjunction(&filter.predicate) + .iter() + .map(|expr| { + let cols = expr.column_refs(); + if cols.iter().any(|c| prevent_cols.contains(&c.name)) { + Ok(false) // No push (keep) + } else { + Ok(true) // push + } + }) + .collect::>>()?; - // going to push some predicates down, so split the predicates - let mut keep_predicates = vec![]; - let mut push_predicates = vec![]; - for (push, expr) in predicate_push_or_keep - .into_iter() - .zip(split_conjunction_owned(filter.predicate).into_iter()) - { - if !push { - keep_predicates.push(expr); - } else { - push_predicates.push(expr); + // all predicates are kept, no changes needed + if predicate_push_or_keep.iter().all(|&x| !x) { + filter.input = Arc::new(LogicalPlan::extension(extension_plan)); + return Ok(Transformed::no(LogicalPlan::filter(filter))); } - } - let new_children = match conjunction(push_predicates) { - Some(predicate) => extension_plan - .node - .inputs() + // going to push some predicates down, so split the predicates + let mut keep_predicates = vec![]; + let mut push_predicates = vec![]; + for (push, expr) in predicate_push_or_keep .into_iter() - .map(|child| { - Ok(LogicalPlan::filter(Filter::try_new( - predicate.clone(), - Arc::new(child.clone()), - )?)) - }) - .collect::>>()?, - None => extension_plan.node.inputs().into_iter().cloned().collect(), - }; - // extension with new inputs. - let child_plan = LogicalPlan::extension(extension_plan); - let new_extension = - child_plan.with_new_exprs(child_plan.expressions(), new_children)?; - - let new_plan = match conjunction(keep_predicates) { - Some(predicate) => LogicalPlan::filter(Filter::try_new( - predicate, - Arc::new(new_extension), - )?), - None => new_extension, - }; - Ok(Transformed::yes(new_plan)) - } - child => { - filter.input = Arc::new(child); - Ok(Transformed::no(LogicalPlan::filter(filter))) + .zip(split_conjunction_owned(filter.predicate).into_iter()) + { + if !push { + keep_predicates.push(expr); + } else { + push_predicates.push(expr); + } + } + + let new_children = match conjunction(push_predicates) { + Some(predicate) => extension_plan + .node + .inputs() + .into_iter() + .map(|child| { + Ok(LogicalPlan::filter(Filter::try_new( + predicate.clone(), + Arc::new(child.clone()), + )?)) + }) + .collect::>>()?, + None => extension_plan.node.inputs().into_iter().cloned().collect(), + }; + // extension with new inputs. + let child_plan = LogicalPlan::extension(extension_plan); + let new_extension = + child_plan.with_new_exprs(child_plan.expressions(), new_children)?; + + let new_plan = match conjunction(keep_predicates) { + Some(predicate) => LogicalPlan::filter(Filter::try_new( + predicate, + Arc::new(new_extension), + )?), + None => new_extension, + }; + Ok(Transformed::yes(new_plan)) + } + child => { + filter.input = Arc::new(child); + Ok(Transformed::no(LogicalPlan::filter(filter))) + } } - } + }) } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index ed526e950eb2..139097236aa1 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -17,15 +17,14 @@ //! [`PushDownLimit`] pushes `LIMIT` earlier in the query plan +use crate::{OptimizerConfig, OptimizerRule}; use std::cmp::min; use std::sync::Arc; -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; - use datafusion_common::tree_node::Transformed; use datafusion_common::utils::combine_limit; use datafusion_common::Result; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; use datafusion_expr::{lit, FetchType, SkipType}; @@ -53,143 +52,148 @@ impl OptimizerRule for PushDownLimit { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - let LogicalPlan::Limit(mut limit, _) = plan else { - return Ok(Transformed::no(plan)); - }; - - // Currently only rewrite if skip and fetch are both literals - let SkipType::Literal(skip) = limit.get_skip_type()? else { - return Ok(Transformed::no(LogicalPlan::limit(limit))); - }; - let FetchType::Literal(fetch) = limit.get_fetch_type()? else { - return Ok(Transformed::no(LogicalPlan::limit(limit))); - }; - - // Merge the Parent Limit and the Child Limit. - if let LogicalPlan::Limit(child, _) = limit.input.as_ref() { - let SkipType::Literal(child_skip) = child.get_skip_type()? else { + plan.transform_down_with_subqueries(|plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanLimit) + { + return Ok(Transformed::jump(plan)); + } + + let LogicalPlan::Limit(mut limit, _) = plan else { + return Ok(Transformed::no(plan)); + }; + + // Currently only rewrite if skip and fetch are both literals + let SkipType::Literal(skip) = limit.get_skip_type()? else { return Ok(Transformed::no(LogicalPlan::limit(limit))); }; - let FetchType::Literal(child_fetch) = child.get_fetch_type()? else { + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { return Ok(Transformed::no(LogicalPlan::limit(limit))); }; - let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); - let plan = LogicalPlan::limit(Limit { - skip: Some(Box::new(lit(skip as i64))), - fetch: fetch.map(|f| Box::new(lit(f as i64))), - input: Arc::clone(&child.input), - }); + // Merge the Parent Limit and the Child Limit. + if let LogicalPlan::Limit(child, _) = limit.input.as_ref() { + let SkipType::Literal(child_skip) = child.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::limit(limit))); + }; + let FetchType::Literal(child_fetch) = child.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::limit(limit))); + }; - // recursively reapply the rule on the new plan - return self.rewrite(plan, _config); - } + let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); + let plan = LogicalPlan::limit(Limit { + skip: Some(Box::new(lit(skip as i64))), + fetch: fetch.map(|f| Box::new(lit(f as i64))), + input: Arc::clone(&child.input), + }); + + // recursively reapply the rule on the new plan + return self.rewrite(plan, _config); + } - // no fetch to push, so return the original plan - let Some(fetch) = fetch else { - return Ok(Transformed::no(LogicalPlan::limit(limit))); - }; - - match Arc::unwrap_or_clone(limit.input) { - LogicalPlan::TableScan(mut scan, _) => { - let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; - let new_fetch = scan - .fetch - .map(|x| min(x, rows_needed)) - .or(Some(rows_needed)); - if new_fetch == scan.fetch { - original_limit(skip, fetch, LogicalPlan::table_scan(scan)) - } else { - // push limit into the table scan itself - scan.fetch = scan + // no fetch to push, so return the original plan + let Some(fetch) = fetch else { + return Ok(Transformed::no(LogicalPlan::limit(limit))); + }; + + match Arc::unwrap_or_clone(limit.input) { + LogicalPlan::TableScan(mut scan, _) => { + let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; + let new_fetch = scan .fetch .map(|x| min(x, rows_needed)) .or(Some(rows_needed)); - transformed_limit(skip, fetch, LogicalPlan::table_scan(scan)) + if new_fetch == scan.fetch { + original_limit(skip, fetch, LogicalPlan::table_scan(scan)) + } else { + // push limit into the table scan itself + scan.fetch = scan + .fetch + .map(|x| min(x, rows_needed)) + .or(Some(rows_needed)); + transformed_limit(skip, fetch, LogicalPlan::table_scan(scan)) + } + } + LogicalPlan::Union(mut union, _) => { + // push limits to each input of the union + union.inputs = union + .inputs + .into_iter() + .map(|input| make_arc_limit(0, fetch + skip, input)) + .collect(); + transformed_limit(skip, fetch, LogicalPlan::union(union)) } - } - LogicalPlan::Union(mut union, _) => { - // push limits to each input of the union - union.inputs = union - .inputs - .into_iter() - .map(|input| make_arc_limit(0, fetch + skip, input)) - .collect(); - transformed_limit(skip, fetch, LogicalPlan::union(union)) - } - - LogicalPlan::Join(join, _) => Ok(push_down_join(join, fetch + skip) - .update_data(|join| { - make_limit(skip, fetch, Arc::new(LogicalPlan::join(join))) - })), - LogicalPlan::Sort(mut sort, _) => { - let new_fetch = { - let sort_fetch = skip + fetch; - Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch)) - }; - if new_fetch == sort.fetch { - if skip > 0 { - original_limit(skip, fetch, LogicalPlan::sort(sort)) + LogicalPlan::Join(join, _) => Ok(push_down_join(join, fetch + skip) + .update_data(|join| { + make_limit(skip, fetch, Arc::new(LogicalPlan::join(join))) + })), + + LogicalPlan::Sort(mut sort, _) => { + let new_fetch = { + let sort_fetch = skip + fetch; + Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch)) + }; + if new_fetch == sort.fetch { + if skip > 0 { + original_limit(skip, fetch, LogicalPlan::sort(sort)) + } else { + Ok(Transformed::yes(LogicalPlan::sort(sort))) + } } else { - Ok(Transformed::yes(LogicalPlan::sort(sort))) + sort.fetch = new_fetch; + limit.input = Arc::new(LogicalPlan::sort(sort)); + Ok(Transformed::yes(LogicalPlan::limit(limit))) } - } else { - sort.fetch = new_fetch; - limit.input = Arc::new(LogicalPlan::sort(sort)); - Ok(Transformed::yes(LogicalPlan::limit(limit))) } - } - LogicalPlan::Projection(mut proj, _) => { - // commute - limit.input = Arc::clone(&proj.input); - let new_limit = LogicalPlan::limit(limit); - proj.input = Arc::new(new_limit); - Ok(Transformed::yes(LogicalPlan::projection(proj))) - } - LogicalPlan::SubqueryAlias(mut subquery_alias, _) => { - // commute - limit.input = Arc::clone(&subquery_alias.input); - let new_limit = LogicalPlan::limit(limit); - subquery_alias.input = Arc::new(new_limit); - Ok(Transformed::yes(LogicalPlan::subquery_alias( - subquery_alias, - ))) - } - LogicalPlan::Extension(extension_plan, _) - if extension_plan.node.supports_limit_pushdown() => - { - let new_children = extension_plan - .node - .inputs() - .into_iter() - .map(|child| { - LogicalPlan::limit(Limit { - skip: None, - fetch: Some(Box::new(lit((fetch + skip) as i64))), - input: Arc::new(child.clone()), + LogicalPlan::Projection(mut proj, _) => { + // commute + limit.input = Arc::clone(&proj.input); + let new_limit = LogicalPlan::limit(limit); + proj.input = Arc::new(new_limit); + Ok(Transformed::yes(LogicalPlan::projection(proj))) + } + LogicalPlan::SubqueryAlias(mut subquery_alias, _) => { + // commute + limit.input = Arc::clone(&subquery_alias.input); + let new_limit = LogicalPlan::limit(limit); + subquery_alias.input = Arc::new(new_limit); + Ok(Transformed::yes(LogicalPlan::subquery_alias( + subquery_alias, + ))) + } + LogicalPlan::Extension(extension_plan, _) + if extension_plan.node.supports_limit_pushdown() => + { + let new_children = extension_plan + .node + .inputs() + .into_iter() + .map(|child| { + LogicalPlan::limit(Limit { + skip: None, + fetch: Some(Box::new(lit((fetch + skip) as i64))), + input: Arc::new(child.clone()), + }) }) - }) - .collect::>(); + .collect::>(); - // Create a new extension node with updated inputs - let child_plan = LogicalPlan::extension(extension_plan); - let new_extension = - child_plan.with_new_exprs(child_plan.expressions(), new_children)?; + // Create a new extension node with updated inputs + let child_plan = LogicalPlan::extension(extension_plan); + let new_extension = child_plan + .with_new_exprs(child_plan.expressions(), new_children)?; - transformed_limit(skip, fetch, new_extension) + transformed_limit(skip, fetch, new_extension) + } + input => original_limit(skip, fetch, input), } - input => original_limit(skip, fetch, input), - } + }) } fn name(&self) -> &str { "push_down_limit" } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } } /// Wrap the input plan with a limit node diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 475b58570d4a..d72324e0f660 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -17,16 +17,16 @@ //! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...` -use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; -use std::sync::Arc; - use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; +use std::cell::Cell; +use std::sync::Arc; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -74,115 +74,140 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Distinct(Distinct::All(input), _) => { - let group_expr = expand_wildcard(input.schema(), &input, None)?; - - let field_count = input.schema().fields().len(); - for dep in input.schema().functional_dependencies().iter() { - // If distinct is exactly the same with a previous GROUP BY, we can - // simply remove it: - if dep.source_indices.len() >= field_count - && dep.source_indices[..field_count] - .iter() - .enumerate() - .all(|(idx, f_idx)| idx == *f_idx) - { - return Ok(Transformed::yes(input.as_ref().clone())); - } + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !plan + .stats() + .contains_pattern(LogicalPlanPattern::LogicalPlanDistinct) + { + skip.set(true); + return Ok(Transformed::jump(plan)); + } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); } - // Replace with aggregation: - let aggr_plan = LogicalPlan::aggregate(Aggregate::try_new( - input, - group_expr, - vec![], - )?); - Ok(Transformed::yes(aggr_plan)) - } - LogicalPlan::Distinct( - Distinct::On(DistinctOn { - select_expr, - on_expr, - sort_expr, - input, - schema, - }), - _, - ) => { - let expr_cnt = on_expr.len(); - - // Construct the aggregation expression to be used to fetch the selected expressions. - let first_value_udaf: Arc = - config.function_registry().unwrap().udaf("first_value")?; - let aggr_expr = select_expr.into_iter().map(|e| { - if let Some(order_by) = &sort_expr { - first_value_udaf - .call(vec![e]) - .order_by(order_by.clone()) - .build() - // guaranteed to be `Expr::AggregateFunction` - .unwrap() - } else { - first_value_udaf.call(vec![e]) + match plan { + LogicalPlan::Distinct(Distinct::All(input), _) => { + let group_expr = expand_wildcard(input.schema(), &input, None)?; + + let field_count = input.schema().fields().len(); + for dep in input.schema().functional_dependencies().iter() { + // If distinct is exactly the same with a previous GROUP BY, we can + // simply remove it: + if dep.source_indices.len() >= field_count + && dep.source_indices[..field_count] + .iter() + .enumerate() + .all(|(idx, f_idx)| idx == *f_idx) + { + return Ok(Transformed::yes(input.as_ref().clone())); + } + } + + // Replace with aggregation: + let aggr_plan = LogicalPlan::aggregate(Aggregate::try_new( + input, + group_expr, + vec![], + )?); + Ok(Transformed::yes(aggr_plan)) } - }); - - let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; - let group_expr = normalize_cols(on_expr, input.as_ref())?; - - // Build the aggregation plan - let plan = LogicalPlan::aggregate(Aggregate::try_new( - input, group_expr, aggr_expr, - )?); - // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate - // when https://github.com/apache/datafusion/issues/10485 is available - let lpb = LogicalPlanBuilder::from(plan); - - let plan = if let Some(mut sort_expr) = sort_expr { - // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, - // this on it's own isn't enough to guarantee the proper output order of the grouping - // (`ON`) expression, so we need to sort those as well. - - // truncate the sort_expr to the length of on_expr - sort_expr.truncate(expr_cnt); - - lpb.sort(sort_expr)?.build()? - } else { - lpb.build()? - }; - - // Whereas the aggregation plan by default outputs both the grouping and the aggregation - // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. - - let project_exprs = plan - .schema() - .iter() - .skip(expr_cnt) - .zip(schema.iter()) - .map(|((new_qualifier, new_field), (old_qualifier, old_field))| { - col(Column::from((new_qualifier, new_field))) - .alias_qualified(old_qualifier.cloned(), old_field.name()) - }) - .collect::>(); - - let plan = LogicalPlanBuilder::from(plan) - .project(project_exprs)? - .build()?; - - Ok(Transformed::yes(plan)) - } - _ => Ok(Transformed::no(plan)), - } + LogicalPlan::Distinct( + Distinct::On(DistinctOn { + select_expr, + on_expr, + sort_expr, + input, + schema, + }), + _, + ) => { + let expr_cnt = on_expr.len(); + + // Construct the aggregation expression to be used to fetch the selected expressions. + let first_value_udaf: Arc = + config.function_registry().unwrap().udaf("first_value")?; + let aggr_expr = select_expr.into_iter().map(|e| { + if let Some(order_by) = &sort_expr { + first_value_udaf + .call(vec![e]) + .order_by(order_by.clone()) + .build() + // guaranteed to be `Expr::AggregateFunction` + .unwrap() + } else { + first_value_udaf.call(vec![e]) + } + }); + + let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; + let group_expr = normalize_cols(on_expr, input.as_ref())?; + + // Build the aggregation plan + let plan = LogicalPlan::aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?); + // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate + // when https://github.com/apache/datafusion/issues/10485 is available + let lpb = LogicalPlanBuilder::from(plan); + + let plan = if let Some(mut sort_expr) = sort_expr { + // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, + // this on it's own isn't enough to guarantee the proper output order of the grouping + // (`ON`) expression, so we need to sort those as well. + + // truncate the sort_expr to the length of on_expr + sort_expr.truncate(expr_cnt); + + lpb.sort(sort_expr)?.build()? + } else { + lpb.build()? + }; + + // Whereas the aggregation plan by default outputs both the grouping and the aggregation + // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. + + let project_exprs = plan + .schema() + .iter() + .skip(expr_cnt) + .zip(schema.iter()) + .map( + |( + (new_qualifier, new_field), + (old_qualifier, old_field), + )| { + col(Column::from((new_qualifier, new_field))) + .alias_qualified( + old_qualifier.cloned(), + old_field.name(), + ) + }, + ) + .collect::>(); + + let plan = LogicalPlanBuilder::from(plan) + .project(project_exprs)? + .build()?; + + Ok(Transformed::yes(plan)) + } + _ => Ok(Transformed::no(plan)), + } + }, + ) } fn name(&self) -> &str { "replace_distinct_aggregate" } - - fn apply_order(&self) -> Option { - Some(BottomUp) - } } #[cfg(test)] diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index f9b247fc9a98..a29cc3c951f5 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -17,14 +17,13 @@ //! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s -use std::collections::{BTreeSet, HashMap}; -use std::ops::Not; -use std::sync::Arc; - use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; -use crate::optimizer::ApplyOrder; use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; +use enumset::enum_set; +use std::collections::{BTreeSet, HashMap}; +use std::ops::Not; +use std::sync::Arc; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ @@ -32,6 +31,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; @@ -79,142 +79,151 @@ impl OptimizerRule for ScalarSubqueryToJoin { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Filter(filter, _) => { - // Optimization: skip the rest of the rule and its copies if - // there are no scalar subqueries - if !contains_scalar_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::filter(filter))); - } + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::LogicalPlanFilter + | LogicalPlanPattern::LogicalPlanProjection + )) { + return Ok(Transformed::jump(plan)); + } - let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( - &filter.predicate, - config.alias_generator(), - )?; + match plan { + LogicalPlan::Filter(filter, _) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !contains_scalar_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::filter(filter))); + } - if subqueries.is_empty() { - return internal_err!("Expected subqueries not found in filter"); - } + let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( + &filter.predicate, + config.alias_generator(), + )?; - // iterate through all subqueries in predicate, turning each into a left join - let mut cur_input = filter.input.as_ref().clone(); - for (subquery, alias) in subqueries { - if let Some((optimized_subquery, expr_check_map)) = - build_join(&subquery, &cur_input, &alias)? - { - if !expr_check_map.is_empty() { - rewrite_expr = rewrite_expr - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = expr - .try_as_col() - .and_then(|col| expr_check_map.get(&col.name)) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) - .data()?; + if subqueries.is_empty() { + return internal_err!("Expected subqueries not found in filter"); + } + + // iterate through all subqueries in predicate, turning each into a left join + let mut cur_input = filter.input.as_ref().clone(); + for (subquery, alias) in subqueries { + if let Some((optimized_subquery, expr_check_map)) = + build_join(&subquery, &cur_input, &alias)? + { + if !expr_check_map.is_empty() { + rewrite_expr = rewrite_expr + .transform_up(|expr| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }) + .data()?; + } + cur_input = optimized_subquery; + } else { + // if we can't handle all of the subqueries then bail for now + return Ok(Transformed::no(LogicalPlan::filter(filter))); } - cur_input = optimized_subquery; - } else { - // if we can't handle all of the subqueries then bail for now - return Ok(Transformed::no(LogicalPlan::filter(filter))); } + let new_plan = LogicalPlanBuilder::from(cur_input) + .filter(rewrite_expr)? + .build()?; + Ok(Transformed::yes(new_plan)) } - let new_plan = LogicalPlanBuilder::from(cur_input) - .filter(rewrite_expr)? - .build()?; - Ok(Transformed::yes(new_plan)) - } - LogicalPlan::Projection(projection, _) => { - // Optimization: skip the rest of the rule and its copies if - // there are no scalar subqueries - if !projection.expr.iter().any(contains_scalar_subquery) { - return Ok(Transformed::no(LogicalPlan::projection(projection))); - } + LogicalPlan::Projection(projection, _) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !projection.expr.iter().any(contains_scalar_subquery) { + return Ok(Transformed::no(LogicalPlan::projection(projection))); + } - let mut all_subqueryies = vec![]; - let mut expr_to_rewrite_expr_map = HashMap::new(); - let mut subquery_to_expr_map = HashMap::new(); - for expr in projection.expr.iter() { - let (subqueries, rewrite_exprs) = - self.extract_subquery_exprs(expr, config.alias_generator())?; - for (subquery, _) in &subqueries { - subquery_to_expr_map.insert(subquery.clone(), expr.clone()); + let mut all_subqueryies = vec![]; + let mut expr_to_rewrite_expr_map = HashMap::new(); + let mut subquery_to_expr_map = HashMap::new(); + for expr in projection.expr.iter() { + let (subqueries, rewrite_exprs) = + self.extract_subquery_exprs(expr, config.alias_generator())?; + for (subquery, _) in &subqueries { + subquery_to_expr_map.insert(subquery.clone(), expr.clone()); + } + all_subqueryies.extend(subqueries); + expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); } - all_subqueryies.extend(subqueries); - expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); - } - if all_subqueryies.is_empty() { - return internal_err!("Expected subqueries not found in projection"); - } - // iterate through all subqueries in predicate, turning each into a left join - let mut cur_input = projection.input.as_ref().clone(); - for (subquery, alias) in all_subqueryies { - if let Some((optimized_subquery, expr_check_map)) = - build_join(&subquery, &cur_input, &alias)? - { - cur_input = optimized_subquery; - if !expr_check_map.is_empty() { - if let Some(expr) = subquery_to_expr_map.get(&subquery) { - if let Some(rewrite_expr) = - expr_to_rewrite_expr_map.get(expr) - { - let new_expr = rewrite_expr - .clone() - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = - expr.try_as_col().and_then(|col| { - expr_check_map.get(&col.name) - }) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) - .data()?; - expr_to_rewrite_expr_map.insert(expr, new_expr); + if all_subqueryies.is_empty() { + return internal_err!( + "Expected subqueries not found in projection" + ); + } + // iterate through all subqueries in predicate, turning each into a left join + let mut cur_input = projection.input.as_ref().clone(); + for (subquery, alias) in all_subqueryies { + if let Some((optimized_subquery, expr_check_map)) = + build_join(&subquery, &cur_input, &alias)? + { + cur_input = optimized_subquery; + if !expr_check_map.is_empty() { + if let Some(expr) = subquery_to_expr_map.get(&subquery) { + if let Some(rewrite_expr) = + expr_to_rewrite_expr_map.get(expr) + { + let new_expr = rewrite_expr + .clone() + .transform_up(|expr| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = + expr.try_as_col().and_then(|col| { + expr_check_map.get(&col.name) + }) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }) + .data()?; + expr_to_rewrite_expr_map.insert(expr, new_expr); + } } } + } else { + // if we can't handle all of the subqueries then bail for now + return Ok(Transformed::no(LogicalPlan::projection( + projection, + ))); } - } else { - // if we can't handle all of the subqueries then bail for now - return Ok(Transformed::no(LogicalPlan::projection(projection))); } - } - let mut proj_exprs = vec![]; - for expr in projection.expr.iter() { - let old_expr_name = expr.schema_name().to_string(); - let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap(); - let new_expr_name = new_expr.schema_name().to_string(); - if new_expr_name != old_expr_name { - proj_exprs.push(new_expr.clone().alias(old_expr_name)) - } else { - proj_exprs.push(new_expr.clone()); + let mut proj_exprs = vec![]; + for expr in projection.expr.iter() { + let old_expr_name = expr.schema_name().to_string(); + let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap(); + let new_expr_name = new_expr.schema_name().to_string(); + if new_expr_name != old_expr_name { + proj_exprs.push(new_expr.clone().alias(old_expr_name)) + } else { + proj_exprs.push(new_expr.clone()); + } } + let new_plan = LogicalPlanBuilder::from(cur_input) + .project(proj_exprs)? + .build()?; + Ok(Transformed::yes(new_plan)) } - let new_plan = LogicalPlanBuilder::from(cur_input) - .project(proj_exprs)? - .build()?; - Ok(Transformed::yes(new_plan)) - } - plan => Ok(Transformed::no(plan)), - } + plan => Ok(Transformed::no(plan)), + } + }) } fn name(&self) -> &str { "scalar_subquery_to_join" } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } } /// Returns true if the expression has a scalar subquery somewhere in it diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 33dda04444c6..0dbcbd06910c 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -43,6 +43,7 @@ use datafusion_expr::{ utils::{iter_conjunction, iter_conjunction_owned}, }; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; +use enumset::enum_set; use indexmap::IndexSet; use super::inlist_simplifier::ShortenInListSimplifier; @@ -51,6 +52,7 @@ use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::SimplifyInfo; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use regex::Regex; /// This structure handles API for expression simplification @@ -409,18 +411,40 @@ impl ExprSimplifier { /// /// ` ` is rewritten so that the name of `col1` sorts higher /// than `col2` (`a > b` would be canonicalized to `b < a`) -struct Canonicalizer {} +struct Canonicalizer { + skip: bool, +} impl Canonicalizer { fn new() -> Self { - Self {} + Self { skip: false } } } impl TreeNodeRewriter for Canonicalizer { type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> Result> { + if !(node + .stats() + .contains_pattern(LogicalPlanPattern::ExprBinaryExpr) + && node.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprColumn | LogicalPlanPattern::ExprLiteral + ))) + { + self.skip = true; + return Ok(Transformed::jump(node)); + } + + Ok(Transformed::no(node)) + } + fn f_up(&mut self, expr: Expr) -> Result> { + if self.skip { + self.skip = false; + return Ok(Transformed::no(expr)); + } + let Expr::BinaryExpr(BinaryExpr { left, op, right }, _) = expr else { return Ok(Transformed::no(expr)); }; @@ -701,19 +725,49 @@ impl<'a> ConstEvaluator<'a> { /// * `expr = null` and `expr != null` to `null` struct Simplifier<'a, S> { info: &'a S, + skip: bool, } impl<'a, S> Simplifier<'a, S> { pub fn new(info: &'a S) -> Self { - Self { info } + Self { info, skip: false } } } impl TreeNodeRewriter for Simplifier<'_, S> { type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> Result> { + if !node.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprBinaryExpr + | LogicalPlanPattern::ExprNot + | LogicalPlanPattern::ExprNegative + | LogicalPlanPattern::ExprCase + | LogicalPlanPattern::ExprScalarFunction + | LogicalPlanPattern::ExprAggregateFunction + | LogicalPlanPattern::ExprWindowFunction + | LogicalPlanPattern::ExprBetween + | LogicalPlanPattern::ExprIsNotNull + | LogicalPlanPattern::ExprIsNull + | LogicalPlanPattern::ExprIsNotUnknown + | LogicalPlanPattern::ExprIsUnknown + | LogicalPlanPattern::ExprInList + | LogicalPlanPattern::ExprLike + )) { + self.skip = true; + return Ok(Transformed::jump(node)); + } + + Ok(Transformed::no(node)) + } + /// rewrite the expression simplifying any constant expressions fn f_up(&mut self, expr: Expr) -> Result> { + if self.skip { + self.skip = false; + return Ok(Transformed::no(expr)); + } + use datafusion_expr::Operator::{ And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor, Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch, diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 579a365bb57c..04131e86c941 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -22,7 +22,9 @@ use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; +use enumset::enum_set; use std::{borrow::Cow, collections::HashMap}; /// Rewrite expressions to incorporate guarantees. @@ -40,6 +42,7 @@ use std::{borrow::Cow, collections::HashMap}; /// [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees pub struct GuaranteeRewriter<'a> { guarantees: HashMap<&'a Expr, &'a NullableInterval>, + skip: bool, } impl<'a> GuaranteeRewriter<'a> { @@ -52,6 +55,7 @@ impl<'a> GuaranteeRewriter<'a> { // issue is fixed. #[allow(clippy::map_identity)] guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), + skip: false, } } } @@ -59,7 +63,32 @@ impl<'a> GuaranteeRewriter<'a> { impl TreeNodeRewriter for GuaranteeRewriter<'_> { type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> Result> { + if !node.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprIsNull + | LogicalPlanPattern::ExprIsNotNull + | LogicalPlanPattern::ExprBetween + | LogicalPlanPattern::ExprBinaryExpr + | LogicalPlanPattern::ExprColumn + | LogicalPlanPattern::ExprInList + )) { + self.skip = true; + return Ok(Transformed::jump(node)); + } + + Ok(Transformed::no(node)) + } + fn f_up(&mut self, expr: Expr) -> Result> { + if self.skip { + self.skip = false; + return Ok(Transformed::no(expr)); + } + + if self.guarantees.is_empty() { + return Ok(Transformed::no(expr)); + } + match &expr { Expr::IsNull(inner, _) => match self.guarantees.get(inner.as_ref()) { Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))), diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 4f6a7832532c..1fa04040b638 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -22,20 +22,40 @@ use super::THRESHOLD_INLINE_INLIST; use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_expr::expr::InList; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::Expr; -pub(super) struct ShortenInListSimplifier {} +pub(super) struct ShortenInListSimplifier { + skip: bool, +} impl ShortenInListSimplifier { pub(super) fn new() -> Self { - Self {} + Self { skip: false } } } impl TreeNodeRewriter for ShortenInListSimplifier { type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> Result> { + if !node + .stats() + .contains_pattern(LogicalPlanPattern::ExprInList) + { + self.skip = true; + return Ok(Transformed::jump(node)); + } + + Ok(Transformed::no(node)) + } + fn f_up(&mut self, expr: Expr) -> Result> { + if self.skip { + self.skip = false; + return Ok(Transformed::no(expr)); + } + // if expr is a single column reference: // expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C) if let Expr::InList( diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 1cefe352fa93..d96e8447f106 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -19,13 +19,13 @@ use std::sync::Arc; -use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{ internal_err, tree_node::Transformed, DataFusionError, HashSet, Result, }; use datafusion_expr::builder::project; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::{ col, expr::AggregateFunction, @@ -109,10 +109,6 @@ impl OptimizerRule for SingleDistinctToGroupBy { "single_distinct_aggregation_to_group_by" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - fn supports_rewrite(&self) -> bool { true } @@ -122,7 +118,12 @@ impl OptimizerRule for SingleDistinctToGroupBy { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { - match plan { + plan.transform_down_with_subqueries(|plan| { + if !plan.stats().contains_pattern(LogicalPlanPattern::LogicalPlanAggregate) { + return Ok(Transformed::jump(plan)); + } + + match plan { LogicalPlan::Aggregate( Aggregate { input, @@ -277,7 +278,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { Ok(Transformed::yes(project(outer_aggr, alias_expr)?)) } _ => Ok(Transformed::no(plan)), - } + }}) } } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index c8b7ddd0b79c..7caf1be2aecd 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -17,11 +17,11 @@ //! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)` +use std::cell::Cell; use std::cmp::Ordering; use std::mem; use std::sync::Arc; -use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use crate::utils::NamePreserver; @@ -32,8 +32,10 @@ use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; +use datafusion_expr::logical_plan::tree_node::LogicalPlanPattern; use datafusion_expr::utils::merge_schema; use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan}; +use enumset::enum_set; /// [`UnwrapCastInComparison`] attempts to remove casts from /// comparisons to literals ([`ScalarValue`]s) by applying the casts @@ -86,10 +88,6 @@ impl OptimizerRule for UnwrapCastInComparison { "unwrap_cast_in_comparison" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - fn supports_rewrite(&self) -> bool { true } @@ -99,39 +97,94 @@ impl OptimizerRule for UnwrapCastInComparison { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - let mut schema = merge_schema(&plan.inputs()); - - if let LogicalPlan::TableScan(ts, _) = &plan { - let source_schema = DFSchema::try_from_qualified_schema( - ts.table_name.clone(), - &ts.source.schema(), - )?; - schema.merge(&source_schema); - } + let skip = Cell::new(false); + plan.transform_down_up_with_subqueries( + |plan| { + if !(plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprBinaryExpr | LogicalPlanPattern::ExprInList + )) && plan.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprTryCast | LogicalPlanPattern::ExprCast + )) && plan + .stats() + .contains_pattern(LogicalPlanPattern::ExprLiteral)) + { + skip.set(true); + return Ok(Transformed::jump(plan)); + } + + Ok(Transformed::no(plan)) + }, + |plan| { + if skip.get() { + skip.set(false); + return Ok(Transformed::no(plan)); + } + + let mut schema = merge_schema(&plan.inputs()); + + if let LogicalPlan::TableScan(ts, _) = &plan { + let source_schema = DFSchema::try_from_qualified_schema( + ts.table_name.clone(), + &ts.source.schema(), + )?; + schema.merge(&source_schema); + } - schema.merge(plan.schema()); + schema.merge(plan.schema()); - let mut expr_rewriter = UnwrapCastExprRewriter { - schema: Arc::new(schema), - }; + let mut expr_rewriter = UnwrapCastExprRewriter::new(Arc::new(schema)); - let name_preserver = NamePreserver::new(&plan); - plan.map_expressions(|expr| { - let original_name = name_preserver.save(&expr); - expr.rewrite(&mut expr_rewriter) - .map(|transformed| transformed.update_data(|e| original_name.restore(e))) - }) + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr); + expr.rewrite(&mut expr_rewriter).map(|transformed| { + transformed.update_data(|e| original_name.restore(e)) + }) + }) + }, + ) } } struct UnwrapCastExprRewriter { schema: DFSchemaRef, + skip: bool, +} + +impl UnwrapCastExprRewriter { + fn new(schema: DFSchemaRef) -> Self { + Self { + schema, + skip: false, + } + } } impl TreeNodeRewriter for UnwrapCastExprRewriter { type Node = Expr; + fn f_down(&mut self, node: Self::Node) -> Result> { + if !(node.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprBinaryExpr | LogicalPlanPattern::ExprInList + )) && node.stats().contains_any_patterns(enum_set!( + LogicalPlanPattern::ExprTryCast | LogicalPlanPattern::ExprCast + )) && node + .stats() + .contains_pattern(LogicalPlanPattern::ExprLiteral)) + { + self.skip = true; + return Ok(Transformed::jump(node)); + } + + Ok(Transformed::no(node)) + } + fn f_up(&mut self, mut expr: Expr) -> Result> { + if self.skip { + self.skip = false; + return Ok(Transformed::no(expr)); + } + match &mut expr { // For case: // try_cast/cast(expr as data_type) op literal @@ -179,9 +232,12 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { else { return Ok(Transformed::no(expr)); }; - **left = lit(value); // unwrap the cast/try_cast for the right expr - **right = mem::take(right_expr); + expr = Expr::binary_expr(BinaryExpr { + left: Box::new(lit(value)), + op: *op, + right: Box::new(mem::take(right_expr)), + }); Ok(Transformed::yes(expr)) } } @@ -216,8 +272,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { return Ok(Transformed::no(expr)); }; // unwrap the cast/try_cast for the left expr - **left = mem::take(left_expr); - **right = lit(value); + expr = Expr::binary_expr(BinaryExpr { + left: mem::take(left_expr), + op: *op, + right: Box::new(lit(value)), + }); Ok(Transformed::yes(expr)) } } @@ -229,7 +288,9 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) Expr::InList( InList { - expr: left, list, .. + expr: left, + list, + negated, }, _, ) => { @@ -285,8 +346,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { .collect::>>() else { return Ok(Transformed::no(expr)) }; - **left = mem::take(left_expr); - *list = right_exprs; + expr = Expr::_in_list(InList { + expr: Box::new(mem::take(left_expr)), + list: right_exprs, + negated: *negated, + }); Ok(Transformed::yes(expr)) } // TODO: handle other expr type and dfs visit them @@ -858,9 +922,7 @@ mod tests { } fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { - let mut expr_rewriter = UnwrapCastExprRewriter { - schema: Arc::clone(schema), - }; + let mut expr_rewriter = UnwrapCastExprRewriter::new(Arc::clone(schema)); expr.rewrite(&mut expr_rewriter).data().unwrap() } diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index f93180ffbeb3..da7da7608d8c 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -95,7 +95,7 @@ SELECT arrow_cast('1', 'Int16') query error SELECT arrow_cast('1') -query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\), LogicalPlanStats\) +query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\), LogicalPlanStats \{ patterns: EnumSet\(ExprLiteral\) \}\) SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 591322b372c9..a923c5473b89 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -429,7 +429,7 @@ logical_plan 04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 05)--------TableScan: t2 06)--TableScan: t1 projection=[a] -physical_plan_error This feature is not implemented: Physical plan does not support logical expression Exists(Exists { subquery: , negated: false }, LogicalPlanStats) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression Exists(Exists { subquery: , negated: false }, LogicalPlanStats { patterns: EnumSet(ExprColumn | ExprLiteral | ExprAggregateFunction | ExprExists | LogicalPlanProjection | LogicalPlanAggregate) }) statement ok drop table t1; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 04f04b83e961..37a9bc46d8c7 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4062,12 +4062,12 @@ logical_plan 07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)))|depth=1] structs[] 08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int))) 09)------------EmptyRelation -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t1" }), name: "t1_int" }, LogicalPlanStats) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t1" }), name: "t1_int" }, LogicalPlanStats { patterns: EnumSet() }) # Test CROSS JOIN LATERAL syntax (execution) # TODO: https://github.com/apache/datafusion/issues/10048 -query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t1" \}\), name: "t1_int" \}, LogicalPlanStats\) +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t1" \}\), name: "t1_int" \}, LogicalPlanStats \{ patterns: EnumSet\(\) \}\) select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); @@ -4085,12 +4085,12 @@ logical_plan 07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)))|depth=1] structs[] 08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t2.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int))) 09)------------EmptyRelation -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t2" }), name: "t1_int" }, LogicalPlanStats) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t2" }), name: "t1_int" }, LogicalPlanStats { patterns: EnumSet() }) # Test INNER JOIN LATERAL syntax (execution) # TODO: https://github.com/apache/datafusion/issues/10048 -query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t2" \}\), name: "t1_int" \}, LogicalPlanStats\) +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t2" \}\), name: "t1_int" \}, LogicalPlanStats \{ patterns: EnumSet\(\) \}\) select t1_id, t1_name, i from join_t1 t2 inner join lateral (select * from unnest(generate_series(1, t1_int))) as series(i) on(t1_id > i); # Test RIGHT JOIN LATERAL syntax (unsupported) diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index ca553fa1ee77..d7293d0ff4f8 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -60,7 +60,7 @@ logical_plan 06)----------Filter: outer_ref(t1.a) = t2.a 07)------------TableScan: t2 08)----TableScan: t1 -physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery(, LogicalPlanStats) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery(, LogicalPlanStats { patterns: EnumSet(ExprColumn | ExprBinaryExpr | ExprAggregateFunction | ExprScalarSubquery | LogicalPlanProjection | LogicalPlanFilter | LogicalPlanAggregate) }) # set from other table query TT