From a834fda4e446b7567c04f7ced276bfb1a14ff544 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 27 Nov 2024 23:08:38 +0800 Subject: [PATCH 1/6] Implement GroupsAccumulator for corr(x,y) --- .../groups_accumulator/accumulate.rs | 182 +++++++++- .../functions-aggregate/src/correlation.rs | 333 +++++++++++++++++- 2 files changed, 513 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index ac4d0e75535e..95fa9e8bee03 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -371,6 +371,87 @@ pub fn accumulate( } } +/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`) +/// +/// This method assumes that for any input record index, if any of the value column +/// is null, or it's filtered out by `opt_filter`, then the record would be ignored. +/// (won't be accumulated by `value_fn`) +pub fn accumulate_multiple( + group_indices: &[usize], + value_columns: &[&PrimitiveArray], + opt_filter: Option<&BooleanArray>, + mut value_fn: F, +) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, &[T::Native]) + Send, +{ + let acc_cols: Vec<&[T::Native]> = value_columns + .iter() + .map(|arr| arr.values().as_ref()) + .collect(); + + // Calculate `valid_indices` to accumulate, non-valid indices are ignored. + // `valid_indices` is a bit mask corresponding to the `group_indices`. An index + // is considered valid if: + // 1. All columns are non-null at this index. + // 2. Not filtered out by `opt_filter` + + // Take AND from all null buffers of `value_columns`. + let mut combined_nulls: Option = None; + + for arr in value_columns.iter() { + if arr.null_count() > 0 { + let nulls = arr + .nulls() + .expect("If null_count() > 0, nulls must be present"); + match combined_nulls { + None => { + combined_nulls = Some(nulls.clone()); + } + Some(ref mut combined) => { + let result = NullBuffer::union(Some(combined), Some(nulls)).unwrap(); + *combined = result.clone(); + } + } + } + } + + // Take AND from previous combined nulls and `opt_filter`. + let valid_indices = match (combined_nulls, opt_filter) { + (None, None) => None, + (None, Some(filter)) => Some(filter.clone()), + (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)), + (Some(nulls), Some(filter)) => { + let combined = nulls.inner() & filter.values(); + Some(BooleanArray::new(combined, None)) + } + }; + + for col in acc_cols.iter() { + assert_eq!(col.len(), group_indices.len()); + } + + match valid_indices { + None => { + for (idx, &group_idx) in group_indices.iter().enumerate() { + // Get `idx`-th row from all value(accumulate) columns + let row_values: Vec<_> = acc_cols.iter().map(|col| col[idx]).collect(); + value_fn(group_idx, &row_values); + } + } + Some(valid_indices) => { + for (idx, &group_idx) in group_indices.iter().enumerate() { + if valid_indices.value(idx) { + // Get `idx`-th row from all value(accumulate) columns + let row_values: Vec<_> = + acc_cols.iter().map(|col| col[idx]).collect(); + value_fn(group_idx, &row_values); + } + } + } + } +} + /// This function is called to update the accumulator state per row /// when the value is not needed (e.g. COUNT) /// @@ -528,7 +609,7 @@ fn initialize_builder( mod test { use super::*; - use arrow::array::UInt32Array; + use arrow::array::{Int32Array, UInt32Array}; use rand::{rngs::ThreadRng, Rng}; use std::collections::HashSet; @@ -940,4 +1021,103 @@ mod test { .collect() } } + + #[test] + fn test_accumulate_multiple_no_nulls_no_filter() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![1, 2, 3, 4]); + let values2 = Int32Array::from(vec![10, 20, 30, 40]); + let value_columns = vec![values1, values2]; + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + None, + |group_idx, values| { + accumulated.push((group_idx, values.to_vec())); + }, + ); + + let expected = vec![ + (0, vec![1, 10]), + (1, vec![2, 20]), + (0, vec![3, 30]), + (1, vec![4, 40]), + ]; + assert_eq!(accumulated, expected); + } + + #[test] + fn test_accumulate_multiple_with_nulls() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]); + let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); + let value_columns = vec![values1, values2]; + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + None, + |group_idx, values| { + accumulated.push((group_idx, values.to_vec())); + }, + ); + + // Only rows where both columns are non-null should be accumulated + let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])]; + assert_eq!(accumulated, expected); + } + + #[test] + fn test_accumulate_multiple_with_filter() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![1, 2, 3, 4]); + let values2 = Int32Array::from(vec![10, 20, 30, 40]); + let value_columns = vec![values1, values2]; + + let filter = BooleanArray::from(vec![true, false, true, false]); + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + Some(&filter), + |group_idx, values| { + accumulated.push((group_idx, values.to_vec())); + }, + ); + + // Only rows where filter is true should be accumulated + let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])]; + assert_eq!(accumulated, expected); + } + + #[test] + fn test_accumulate_multiple_with_nulls_and_filter() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]); + let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); + let value_columns = vec![values1, values2]; + + let filter = BooleanArray::from(vec![true, true, true, false]); + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + Some(&filter), + |group_idx, values| { + accumulated.push((group_idx, values.to_vec())); + }, + ); + + // Only rows where both: + // 1. Filter is true + // 2. Both columns are non-null + // should be accumulated + let expected = vec![(0, vec![1, 10])]; + assert_eq!(accumulated, expected); + } } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 14124ce46aea..46e7e27dad59 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -22,11 +22,19 @@ use std::fmt::Debug; use std::mem::size_of_val; use std::sync::{Arc, OnceLock}; -use arrow::compute::{and, filter, is_not_null}; +use arrow::array::{ + downcast_array, Array, AsArray, BooleanArray, BooleanBufferBuilder, Float64Array, + UInt64Array, +}; +use arrow::compute::{and, filter, is_not_null, kernels::cast}; +use arrow::datatypes::{Float64Type, UInt64Type}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple; +use log::debug; use crate::covariance::CovarianceAccumulator; use crate::stddev::StddevAccumulator; @@ -113,6 +121,18 @@ impl AggregateUDFImpl for Correlation { fn documentation(&self) -> Option<&Documentation> { Some(get_corr_doc()) } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`"); + Ok(Box::new(CorrelationGroupsAccumulator::new())) + } } static DOCUMENTATION: OnceLock = OnceLock::new(); @@ -263,3 +283,314 @@ impl Accumulator for CorrelationAccumulator { Ok(()) } } + +pub struct CorrelationGroupsAccumulator { + // Number of elements for each group + // This is also used to track nulls: if a group has 0 valid values accumulated, + // final aggregation result will be null. + count: Vec, + // Sum of x values for each group + sum_x: Vec, + // Sum of y + sum_y: Vec, + // Sum of x*y + sum_xy: Vec, + // Sum of x^2 + sum_xx: Vec, + // Sum of y^2 + sum_yy: Vec, +} + +impl CorrelationGroupsAccumulator { + pub fn new() -> Self { + Self { + count: Vec::new(), + sum_x: Vec::new(), + sum_y: Vec::new(), + sum_xy: Vec::new(), + sum_xx: Vec::new(), + sum_yy: Vec::new(), + } + } +} + +/// Specialized version of `accumulate_multiple` for correlation's merge_batch +/// +/// Note: Arrays in `state_arrays` should not have null values, because they are all +/// intermediate states created within the accumulator, instead of inputs from +/// outside. +fn accumulate_correlation_states( + group_indices: &[usize], + state_arrays: ( + &UInt64Array, // count + &Float64Array, // sum_x + &Float64Array, // sum_y + &Float64Array, // sum_xy + &Float64Array, // sum_xx + &Float64Array, // sum_yy + ), + mut value_fn: impl FnMut(usize, u64, &[f64]), +) { + let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays; + + assert_eq!(counts.null_count(), 0); + assert_eq!(sum_x.null_count(), 0); + assert_eq!(sum_y.null_count(), 0); + assert_eq!(sum_xy.null_count(), 0); + assert_eq!(sum_xx.null_count(), 0); + assert_eq!(sum_yy.null_count(), 0); + + let counts_values = counts.values().as_ref(); + let sum_x_values = sum_x.values().as_ref(); + let sum_y_values = sum_y.values().as_ref(); + let sum_xy_values = sum_xy.values().as_ref(); + let sum_xx_values = sum_xx.values().as_ref(); + let sum_yy_values = sum_yy.values().as_ref(); + + let mut row = [0.0; 5]; + for (idx, &group_idx) in group_indices.iter().enumerate() { + row[0] = sum_x_values[idx]; + row[1] = sum_y_values[idx]; + row[2] = sum_xy_values[idx]; + row[3] = sum_xx_values[idx]; + row[4] = sum_yy_values[idx]; + value_fn(group_idx, counts_values[idx], &row); + } +} + +/// GroupsAccumulator implementation for `corr(x, y)` that computes the Pearson correlation coefficient +/// between two numeric columns. +/// +/// Online algorithm for correlation: +/// +/// r = (n * sum_xy - sum_x * sum_y) / sqrt((n * sum_xx - sum_x^2) * (n * sum_yy - sum_y^2)) +/// where: +/// n = number of observations +/// sum_x = sum of x values +/// sum_y = sum of y values +/// sum_xy = sum of (x * y) +/// sum_xx = sum of x^2 values +/// sum_yy = sum of y^2 values +/// +/// Reference: +impl GroupsAccumulator for CorrelationGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.count.resize(total_num_groups, 0); + self.sum_x.resize(total_num_groups, 0.0); + self.sum_y.resize(total_num_groups, 0.0); + self.sum_xy.resize(total_num_groups, 0.0); + self.sum_xx.resize(total_num_groups, 0.0); + self.sum_yy.resize(total_num_groups, 0.0); + + let array_x = &cast(&values[0], &DataType::Float64)?; + let array_x = downcast_array::(array_x); + let array_y = &cast(&values[1], &DataType::Float64)?; + let array_y = downcast_array::(array_y); + + accumulate_multiple( + group_indices, + &[&array_x, &array_y], + opt_filter, + |group_index, values| { + let x = values[0]; + let y = values[1]; + self.count[group_index] += 1; + self.sum_x[group_index] += x; + self.sum_y[group_index] += y; + self.sum_xy[group_index] += x * y; + self.sum_xx[group_index] += x * x; + self.sum_yy[group_index] += y * y; + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // Resize vectors to accommodate total number of groups + self.count.resize(total_num_groups, 0); + self.sum_x.resize(total_num_groups, 0.0); + self.sum_y.resize(total_num_groups, 0.0); + self.sum_xy.resize(total_num_groups, 0.0); + self.sum_xx.resize(total_num_groups, 0.0); + self.sum_yy.resize(total_num_groups, 0.0); + + // Extract arrays from input values + let partial_counts = values[0].as_primitive::(); + let partial_sum_x = values[1].as_primitive::(); + let partial_sum_y = values[2].as_primitive::(); + let partial_sum_xy = values[3].as_primitive::(); + let partial_sum_xx = values[4].as_primitive::(); + let partial_sum_yy = values[5].as_primitive::(); + + assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage"); + + accumulate_correlation_states( + group_indices, + ( + partial_counts, + partial_sum_x, + partial_sum_y, + partial_sum_xy, + partial_sum_xx, + partial_sum_yy, + ), + |group_index, count, values| { + self.count[group_index] += count; + self.sum_x[group_index] += values[0]; + self.sum_y[group_index] += values[1]; + self.sum_xy[group_index] += values[2]; + self.sum_xx[group_index] += values[3]; + self.sum_yy[group_index] += values[4]; + }, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let n = match emit_to { + EmitTo::All => self.count.len(), + EmitTo::First(n) => n, + }; + + let mut values = Vec::with_capacity(n); + let mut nulls = BooleanBufferBuilder::new(n); + + // Notes for `Null` handling: + // - If the `count` state of a group is 0, no valid records are accumulated + // for this group, so the aggregation result is `Null`. + // - Correlation can't be calculated when a group only has 1 record, or when + // the `denominator` state is 0. In these cases, the final aggregation + // result should be `Null` (according to PostgreSQL's behavior). + // + // TODO: Old datafusion implementation returns 0.0 for these invalid cases. + // Update this to match PostgreSQL's behavior. + for i in 0..n { + if self.count[i] < 2 { + // TODO: Evaluate as `Null` (see notes above) + values.push(0.0); + nulls.append(false); + continue; + } + + let count = self.count[i]; + let sum_x = self.sum_x[i]; + let sum_y = self.sum_y[i]; + let sum_xy = self.sum_xy[i]; + let sum_xx = self.sum_xx[i]; + let sum_yy = self.sum_yy[i]; + + let mean_x = sum_x / count as f64; + let mean_y = sum_y / count as f64; + + let numerator = sum_xy - sum_x * mean_y; + let denominator = + ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt(); + + if denominator == 0.0 { + // TODO: Evaluate as `Null` (see notes above) + values.push(0.0); + nulls.append(false); + } else { + values.push(numerator / denominator); + nulls.append(true); + } + } + + Ok(Arc::new(Float64Array::new( + values.into(), + Some(nulls.finish().into()), + ))) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let n = match emit_to { + EmitTo::All => self.count.len(), + EmitTo::First(n) => n, + }; + + Ok(vec![ + Arc::new(UInt64Array::from(self.count[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())), + ]) + } + + fn size(&self) -> usize { + size_of::() + * (self.count.capacity() + + self.sum_x.capacity() + + self.sum_y.capacity() + + self.sum_xy.capacity() + + self.sum_xx.capacity() + + self.sum_yy.capacity()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Float64Array, UInt64Array}; + + #[test] + fn test_accumulate_correlation_states() { + // Test data + let group_indices = vec![0, 1, 0, 1]; + let counts = UInt64Array::from(vec![1, 2, 3, 4]); + let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]); + let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]); + let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]); + let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]); + let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]); + + let mut accumulated = vec![]; + accumulate_correlation_states( + &group_indices, + (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy), + |group_idx, count, values| { + accumulated.push((group_idx, count, values.to_vec())); + }, + ); + + let expected = vec![ + (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]), + (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]), + (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]), + (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]), + ]; + assert_eq!(accumulated, expected); + + // Test that function panics with null values + let counts = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]); + let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]); + let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]); + let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]); + let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]); + let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]); + + let result = std::panic::catch_unwind(|| { + accumulate_correlation_states( + &group_indices, + (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy), + |_, _, _| {}, + ) + }); + assert!(result.is_err()); + } +} From 773b9c56034b43ed6d740cbeb3ba7bef903d3d2e Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Thu, 28 Nov 2024 19:27:44 +0800 Subject: [PATCH 2/6] feedbacks --- .../groups_accumulator/accumulate.rs | 44 ++++++------------- .../functions-aggregate/src/correlation.rs | 10 +---- 2 files changed, 15 insertions(+), 39 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 95fa9e8bee03..38afeda8ccf8 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -385,11 +385,6 @@ pub fn accumulate_multiple( T: ArrowPrimitiveType + Send, F: FnMut(usize, &[T::Native]) + Send, { - let acc_cols: Vec<&[T::Native]> = value_columns - .iter() - .map(|arr| arr.values().as_ref()) - .collect(); - // Calculate `valid_indices` to accumulate, non-valid indices are ignored. // `valid_indices` is a bit mask corresponding to the `group_indices`. An index // is considered valid if: @@ -397,24 +392,10 @@ pub fn accumulate_multiple( // 2. Not filtered out by `opt_filter` // Take AND from all null buffers of `value_columns`. - let mut combined_nulls: Option = None; - - for arr in value_columns.iter() { - if arr.null_count() > 0 { - let nulls = arr - .nulls() - .expect("If null_count() > 0, nulls must be present"); - match combined_nulls { - None => { - combined_nulls = Some(nulls.clone()); - } - Some(ref mut combined) => { - let result = NullBuffer::union(Some(combined), Some(nulls)).unwrap(); - *combined = result.clone(); - } - } - } - } + let combined_nulls = value_columns + .iter() + .map(|arr| arr.nulls()) + .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); // Take AND from previous combined nulls and `opt_filter`. let valid_indices = match (combined_nulls, opt_filter) { @@ -427,7 +408,7 @@ pub fn accumulate_multiple( } }; - for col in acc_cols.iter() { + for col in value_columns.iter() { assert_eq!(col.len(), group_indices.len()); } @@ -435,7 +416,8 @@ pub fn accumulate_multiple( None => { for (idx, &group_idx) in group_indices.iter().enumerate() { // Get `idx`-th row from all value(accumulate) columns - let row_values: Vec<_> = acc_cols.iter().map(|col| col[idx]).collect(); + let row_values: Vec<_> = + value_columns.iter().map(|col| col.value(idx)).collect(); value_fn(group_idx, &row_values); } } @@ -444,7 +426,7 @@ pub fn accumulate_multiple( if valid_indices.value(idx) { // Get `idx`-th row from all value(accumulate) columns let row_values: Vec<_> = - acc_cols.iter().map(|col| col[idx]).collect(); + value_columns.iter().map(|col| col.value(idx)).collect(); value_fn(group_idx, &row_values); } } @@ -1027,7 +1009,7 @@ mod test { let group_indices = vec![0, 1, 0, 1]; let values1 = Int32Array::from(vec![1, 2, 3, 4]); let values2 = Int32Array::from(vec![10, 20, 30, 40]); - let value_columns = vec![values1, values2]; + let value_columns = [values1, values2]; let mut accumulated = vec![]; accumulate_multiple( @@ -1053,7 +1035,7 @@ mod test { let group_indices = vec![0, 1, 0, 1]; let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]); let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); - let value_columns = vec![values1, values2]; + let value_columns = [values1, values2]; let mut accumulated = vec![]; accumulate_multiple( @@ -1075,7 +1057,7 @@ mod test { let group_indices = vec![0, 1, 0, 1]; let values1 = Int32Array::from(vec![1, 2, 3, 4]); let values2 = Int32Array::from(vec![10, 20, 30, 40]); - let value_columns = vec![values1, values2]; + let value_columns = [values1, values2]; let filter = BooleanArray::from(vec![true, false, true, false]); @@ -1099,7 +1081,7 @@ mod test { let group_indices = vec![0, 1, 0, 1]; let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]); let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); - let value_columns = vec![values1, values2]; + let value_columns = [values1, values2]; let filter = BooleanArray::from(vec![true, true, true, false]); @@ -1117,7 +1099,7 @@ mod test { // 1. Filter is true // 2. Both columns are non-null // should be accumulated - let expected = vec![(0, vec![1, 10])]; + let expected = [(0, vec![1, 10])]; assert_eq!(accumulated, expected); } } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 46e7e27dad59..8b1ef72bc3c0 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -284,6 +284,7 @@ impl Accumulator for CorrelationAccumulator { } } +#[derive(Default)] pub struct CorrelationGroupsAccumulator { // Number of elements for each group // This is also used to track nulls: if a group has 0 valid values accumulated, @@ -303,14 +304,7 @@ pub struct CorrelationGroupsAccumulator { impl CorrelationGroupsAccumulator { pub fn new() -> Self { - Self { - count: Vec::new(), - sum_x: Vec::new(), - sum_y: Vec::new(), - sum_xy: Vec::new(), - sum_xx: Vec::new(), - sum_yy: Vec::new(), - } + Default::default() } } From 380ef0a8c21a66ec1d1e419e71ca45da034958fc Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Thu, 28 Nov 2024 20:20:00 +0800 Subject: [PATCH 3/6] fix CI MSRV --- datafusion/functions-aggregate/src/correlation.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 8b1ef72bc3c0..de48f36c46d8 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -527,13 +527,12 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { } fn size(&self) -> usize { - size_of::() - * (self.count.capacity() - + self.sum_x.capacity() - + self.sum_y.capacity() - + self.sum_xy.capacity() - + self.sum_xx.capacity() - + self.sum_yy.capacity()) + size_of_val(&self.count) + + size_of_val(&self.sum_x) + + size_of_val(&self.sum_y) + + size_of_val(&self.sum_xy) + + size_of_val(&self.sum_xx) + + size_of_val(&self.sum_yy) } } From 98cba9183d6085ba68fe0cf33dad6de9062bfdd9 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 11 Dec 2024 00:35:41 +0800 Subject: [PATCH 4/6] review --- .../groups_accumulator/accumulate.rs | 8 ++++--- .../functions-aggregate/src/correlation.rs | 21 +++++++++---------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 38afeda8ccf8..0cfdc0836ad1 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -394,8 +394,10 @@ pub fn accumulate_multiple( // Take AND from all null buffers of `value_columns`. let combined_nulls = value_columns .iter() - .map(|arr| arr.nulls()) - .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); + .map(|arr| arr.logical_nulls()) + .fold(None, |acc, nulls| { + NullBuffer::union(acc.as_ref(), nulls.as_ref()) + }); // Take AND from previous combined nulls and `opt_filter`. let valid_indices = match (combined_nulls, opt_filter) { @@ -409,7 +411,7 @@ pub fn accumulate_multiple( }; for col in value_columns.iter() { - assert_eq!(col.len(), group_indices.len()); + debug_assert_eq!(col.len(), group_indices.len()); } match valid_indices { diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index de48f36c46d8..d35630e5d94f 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -26,7 +26,7 @@ use arrow::array::{ downcast_array, Array, AsArray, BooleanArray, BooleanBufferBuilder, Float64Array, UInt64Array, }; -use arrow::compute::{and, filter, is_not_null, kernels::cast}; +use arrow::compute::{and, filter, is_not_null}; use arrow::datatypes::{Float64Type, UInt64Type}; use arrow::{ array::ArrayRef, @@ -341,13 +341,14 @@ fn accumulate_correlation_states( let sum_xx_values = sum_xx.values().as_ref(); let sum_yy_values = sum_yy.values().as_ref(); - let mut row = [0.0; 5]; for (idx, &group_idx) in group_indices.iter().enumerate() { - row[0] = sum_x_values[idx]; - row[1] = sum_y_values[idx]; - row[2] = sum_xy_values[idx]; - row[3] = sum_xx_values[idx]; - row[4] = sum_yy_values[idx]; + let row = [ + sum_x_values[idx], + sum_y_values[idx], + sum_xy_values[idx], + sum_xx_values[idx], + sum_yy_values[idx], + ]; value_fn(group_idx, counts_values[idx], &row); } } @@ -382,10 +383,8 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { self.sum_xx.resize(total_num_groups, 0.0); self.sum_yy.resize(total_num_groups, 0.0); - let array_x = &cast(&values[0], &DataType::Float64)?; - let array_x = downcast_array::(array_x); - let array_y = &cast(&values[1], &DataType::Float64)?; - let array_y = downcast_array::(array_y); + let array_x = downcast_array::(&values[0]); + let array_y = downcast_array::(&values[1]); accumulate_multiple( group_indices, From 66fb41e821a6c04bd9f12727a15a63c7fe789daa Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 11 Dec 2024 00:57:20 +0800 Subject: [PATCH 5/6] avoid collect in accumulation --- .../groups_accumulator/accumulate.rs | 48 +++++++++++-------- .../functions-aggregate/src/correlation.rs | 6 +-- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 0cfdc0836ad1..e629e99e1657 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -376,6 +376,16 @@ pub fn accumulate( /// This method assumes that for any input record index, if any of the value column /// is null, or it's filtered out by `opt_filter`, then the record would be ignored. /// (won't be accumulated by `value_fn`) +/// +/// # Arguments +/// +/// * `group_indices` - To which groups do the rows in `value_columns` belong +/// * `value_columns` - The input arrays to accumulate +/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included +/// * `value_fn` - Callback function for each valid row, with parameters: +/// * `group_idx`: The group index for the current row +/// * `batch_idx`: The index of the current row in the input arrays +/// * `columns`: Reference to all input arrays for accessing values pub fn accumulate_multiple( group_indices: &[usize], value_columns: &[&PrimitiveArray], @@ -383,7 +393,7 @@ pub fn accumulate_multiple( mut value_fn: F, ) where T: ArrowPrimitiveType + Send, - F: FnMut(usize, &[T::Native]) + Send, + F: FnMut(usize, usize, &[&PrimitiveArray]) + Send, { // Calculate `valid_indices` to accumulate, non-valid indices are ignored. // `valid_indices` is a bit mask corresponding to the `group_indices`. An index @@ -416,20 +426,14 @@ pub fn accumulate_multiple( match valid_indices { None => { - for (idx, &group_idx) in group_indices.iter().enumerate() { - // Get `idx`-th row from all value(accumulate) columns - let row_values: Vec<_> = - value_columns.iter().map(|col| col.value(idx)).collect(); - value_fn(group_idx, &row_values); + for (batch_idx, &group_idx) in group_indices.iter().enumerate() { + value_fn(group_idx, batch_idx, value_columns); } } Some(valid_indices) => { - for (idx, &group_idx) in group_indices.iter().enumerate() { - if valid_indices.value(idx) { - // Get `idx`-th row from all value(accumulate) columns - let row_values: Vec<_> = - value_columns.iter().map(|col| col.value(idx)).collect(); - value_fn(group_idx, &row_values); + for (batch_idx, &group_idx) in group_indices.iter().enumerate() { + if valid_indices.value(batch_idx) { + value_fn(group_idx, batch_idx, value_columns); } } } @@ -1018,8 +1022,9 @@ mod test { &group_indices, &value_columns.iter().collect::>(), None, - |group_idx, values| { - accumulated.push((group_idx, values.to_vec())); + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); }, ); @@ -1044,8 +1049,9 @@ mod test { &group_indices, &value_columns.iter().collect::>(), None, - |group_idx, values| { - accumulated.push((group_idx, values.to_vec())); + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); }, ); @@ -1068,8 +1074,9 @@ mod test { &group_indices, &value_columns.iter().collect::>(), Some(&filter), - |group_idx, values| { - accumulated.push((group_idx, values.to_vec())); + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); }, ); @@ -1092,8 +1099,9 @@ mod test { &group_indices, &value_columns.iter().collect::>(), Some(&filter), - |group_idx, values| { - accumulated.push((group_idx, values.to_vec())); + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); }, ); diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index d35630e5d94f..40fbebf5bd1d 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -390,9 +390,9 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { group_indices, &[&array_x, &array_y], opt_filter, - |group_index, values| { - let x = values[0]; - let y = values[1]; + |group_index, batch_index, columns| { + let x = columns[0].value(batch_index); + let y = columns[1].value(batch_index); self.count[group_index] += 1; self.sum_x[group_index] += x; self.sum_y[group_index] += y; From 8c84406776109a5257fb93af65078105cdff1fcc Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 11 Dec 2024 01:19:28 +0800 Subject: [PATCH 6/6] add back cast --- datafusion/functions-aggregate/src/correlation.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 40fbebf5bd1d..935546b4b549 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -26,7 +26,7 @@ use arrow::array::{ downcast_array, Array, AsArray, BooleanArray, BooleanBufferBuilder, Float64Array, UInt64Array, }; -use arrow::compute::{and, filter, is_not_null}; +use arrow::compute::{and, filter, is_not_null, kernels::cast}; use arrow::datatypes::{Float64Type, UInt64Type}; use arrow::{ array::ArrayRef, @@ -383,8 +383,10 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { self.sum_xx.resize(total_num_groups, 0.0); self.sum_yy.resize(total_num_groups, 0.0); - let array_x = downcast_array::(&values[0]); - let array_y = downcast_array::(&values[1]); + let array_x = &cast(&values[0], &DataType::Float64)?; + let array_x = downcast_array::(array_x); + let array_y = &cast(&values[1], &DataType::Float64)?; + let array_y = downcast_array::(array_y); accumulate_multiple( group_indices,