diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index ac4d0e75535e..e629e99e1657 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -371,6 +371,75 @@ pub fn accumulate( } } +/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`) +/// +/// This method assumes that for any input record index, if any of the value column +/// is null, or it's filtered out by `opt_filter`, then the record would be ignored. +/// (won't be accumulated by `value_fn`) +/// +/// # Arguments +/// +/// * `group_indices` - To which groups do the rows in `value_columns` belong +/// * `value_columns` - The input arrays to accumulate +/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included +/// * `value_fn` - Callback function for each valid row, with parameters: +/// * `group_idx`: The group index for the current row +/// * `batch_idx`: The index of the current row in the input arrays +/// * `columns`: Reference to all input arrays for accessing values +pub fn accumulate_multiple( + group_indices: &[usize], + value_columns: &[&PrimitiveArray], + opt_filter: Option<&BooleanArray>, + mut value_fn: F, +) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, usize, &[&PrimitiveArray]) + Send, +{ + // Calculate `valid_indices` to accumulate, non-valid indices are ignored. + // `valid_indices` is a bit mask corresponding to the `group_indices`. An index + // is considered valid if: + // 1. All columns are non-null at this index. + // 2. Not filtered out by `opt_filter` + + // Take AND from all null buffers of `value_columns`. + let combined_nulls = value_columns + .iter() + .map(|arr| arr.logical_nulls()) + .fold(None, |acc, nulls| { + NullBuffer::union(acc.as_ref(), nulls.as_ref()) + }); + + // Take AND from previous combined nulls and `opt_filter`. + let valid_indices = match (combined_nulls, opt_filter) { + (None, None) => None, + (None, Some(filter)) => Some(filter.clone()), + (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)), + (Some(nulls), Some(filter)) => { + let combined = nulls.inner() & filter.values(); + Some(BooleanArray::new(combined, None)) + } + }; + + for col in value_columns.iter() { + debug_assert_eq!(col.len(), group_indices.len()); + } + + match valid_indices { + None => { + for (batch_idx, &group_idx) in group_indices.iter().enumerate() { + value_fn(group_idx, batch_idx, value_columns); + } + } + Some(valid_indices) => { + for (batch_idx, &group_idx) in group_indices.iter().enumerate() { + if valid_indices.value(batch_idx) { + value_fn(group_idx, batch_idx, value_columns); + } + } + } + } +} + /// This function is called to update the accumulator state per row /// when the value is not needed (e.g. COUNT) /// @@ -528,7 +597,7 @@ fn initialize_builder( mod test { use super::*; - use arrow::array::UInt32Array; + use arrow::array::{Int32Array, UInt32Array}; use rand::{rngs::ThreadRng, Rng}; use std::collections::HashSet; @@ -940,4 +1009,107 @@ mod test { .collect() } } + + #[test] + fn test_accumulate_multiple_no_nulls_no_filter() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![1, 2, 3, 4]); + let values2 = Int32Array::from(vec![10, 20, 30, 40]); + let value_columns = [values1, values2]; + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + None, + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); + }, + ); + + let expected = vec![ + (0, vec![1, 10]), + (1, vec![2, 20]), + (0, vec![3, 30]), + (1, vec![4, 40]), + ]; + assert_eq!(accumulated, expected); + } + + #[test] + fn test_accumulate_multiple_with_nulls() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]); + let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); + let value_columns = [values1, values2]; + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + None, + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); + }, + ); + + // Only rows where both columns are non-null should be accumulated + let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])]; + assert_eq!(accumulated, expected); + } + + #[test] + fn test_accumulate_multiple_with_filter() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![1, 2, 3, 4]); + let values2 = Int32Array::from(vec![10, 20, 30, 40]); + let value_columns = [values1, values2]; + + let filter = BooleanArray::from(vec![true, false, true, false]); + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + Some(&filter), + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); + }, + ); + + // Only rows where filter is true should be accumulated + let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])]; + assert_eq!(accumulated, expected); + } + + #[test] + fn test_accumulate_multiple_with_nulls_and_filter() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]); + let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]); + let value_columns = [values1, values2]; + + let filter = BooleanArray::from(vec![true, true, true, false]); + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + Some(&filter), + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); + }, + ); + + // Only rows where both: + // 1. Filter is true + // 2. Both columns are non-null + // should be accumulated + let expected = [(0, vec![1, 10])]; + assert_eq!(accumulated, expected); + } } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 14124ce46aea..935546b4b549 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -22,11 +22,19 @@ use std::fmt::Debug; use std::mem::size_of_val; use std::sync::{Arc, OnceLock}; -use arrow::compute::{and, filter, is_not_null}; +use arrow::array::{ + downcast_array, Array, AsArray, BooleanArray, BooleanBufferBuilder, Float64Array, + UInt64Array, +}; +use arrow::compute::{and, filter, is_not_null, kernels::cast}; +use arrow::datatypes::{Float64Type, UInt64Type}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple; +use log::debug; use crate::covariance::CovarianceAccumulator; use crate::stddev::StddevAccumulator; @@ -113,6 +121,18 @@ impl AggregateUDFImpl for Correlation { fn documentation(&self) -> Option<&Documentation> { Some(get_corr_doc()) } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + debug!("GroupsAccumulator is created for aggregate function `corr(c1, c2)`"); + Ok(Box::new(CorrelationGroupsAccumulator::new())) + } } static DOCUMENTATION: OnceLock = OnceLock::new(); @@ -263,3 +283,308 @@ impl Accumulator for CorrelationAccumulator { Ok(()) } } + +#[derive(Default)] +pub struct CorrelationGroupsAccumulator { + // Number of elements for each group + // This is also used to track nulls: if a group has 0 valid values accumulated, + // final aggregation result will be null. + count: Vec, + // Sum of x values for each group + sum_x: Vec, + // Sum of y + sum_y: Vec, + // Sum of x*y + sum_xy: Vec, + // Sum of x^2 + sum_xx: Vec, + // Sum of y^2 + sum_yy: Vec, +} + +impl CorrelationGroupsAccumulator { + pub fn new() -> Self { + Default::default() + } +} + +/// Specialized version of `accumulate_multiple` for correlation's merge_batch +/// +/// Note: Arrays in `state_arrays` should not have null values, because they are all +/// intermediate states created within the accumulator, instead of inputs from +/// outside. +fn accumulate_correlation_states( + group_indices: &[usize], + state_arrays: ( + &UInt64Array, // count + &Float64Array, // sum_x + &Float64Array, // sum_y + &Float64Array, // sum_xy + &Float64Array, // sum_xx + &Float64Array, // sum_yy + ), + mut value_fn: impl FnMut(usize, u64, &[f64]), +) { + let (counts, sum_x, sum_y, sum_xy, sum_xx, sum_yy) = state_arrays; + + assert_eq!(counts.null_count(), 0); + assert_eq!(sum_x.null_count(), 0); + assert_eq!(sum_y.null_count(), 0); + assert_eq!(sum_xy.null_count(), 0); + assert_eq!(sum_xx.null_count(), 0); + assert_eq!(sum_yy.null_count(), 0); + + let counts_values = counts.values().as_ref(); + let sum_x_values = sum_x.values().as_ref(); + let sum_y_values = sum_y.values().as_ref(); + let sum_xy_values = sum_xy.values().as_ref(); + let sum_xx_values = sum_xx.values().as_ref(); + let sum_yy_values = sum_yy.values().as_ref(); + + for (idx, &group_idx) in group_indices.iter().enumerate() { + let row = [ + sum_x_values[idx], + sum_y_values[idx], + sum_xy_values[idx], + sum_xx_values[idx], + sum_yy_values[idx], + ]; + value_fn(group_idx, counts_values[idx], &row); + } +} + +/// GroupsAccumulator implementation for `corr(x, y)` that computes the Pearson correlation coefficient +/// between two numeric columns. +/// +/// Online algorithm for correlation: +/// +/// r = (n * sum_xy - sum_x * sum_y) / sqrt((n * sum_xx - sum_x^2) * (n * sum_yy - sum_y^2)) +/// where: +/// n = number of observations +/// sum_x = sum of x values +/// sum_y = sum of y values +/// sum_xy = sum of (x * y) +/// sum_xx = sum of x^2 values +/// sum_yy = sum of y^2 values +/// +/// Reference: +impl GroupsAccumulator for CorrelationGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.count.resize(total_num_groups, 0); + self.sum_x.resize(total_num_groups, 0.0); + self.sum_y.resize(total_num_groups, 0.0); + self.sum_xy.resize(total_num_groups, 0.0); + self.sum_xx.resize(total_num_groups, 0.0); + self.sum_yy.resize(total_num_groups, 0.0); + + let array_x = &cast(&values[0], &DataType::Float64)?; + let array_x = downcast_array::(array_x); + let array_y = &cast(&values[1], &DataType::Float64)?; + let array_y = downcast_array::(array_y); + + accumulate_multiple( + group_indices, + &[&array_x, &array_y], + opt_filter, + |group_index, batch_index, columns| { + let x = columns[0].value(batch_index); + let y = columns[1].value(batch_index); + self.count[group_index] += 1; + self.sum_x[group_index] += x; + self.sum_y[group_index] += y; + self.sum_xy[group_index] += x * y; + self.sum_xx[group_index] += x * x; + self.sum_yy[group_index] += y * y; + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // Resize vectors to accommodate total number of groups + self.count.resize(total_num_groups, 0); + self.sum_x.resize(total_num_groups, 0.0); + self.sum_y.resize(total_num_groups, 0.0); + self.sum_xy.resize(total_num_groups, 0.0); + self.sum_xx.resize(total_num_groups, 0.0); + self.sum_yy.resize(total_num_groups, 0.0); + + // Extract arrays from input values + let partial_counts = values[0].as_primitive::(); + let partial_sum_x = values[1].as_primitive::(); + let partial_sum_y = values[2].as_primitive::(); + let partial_sum_xy = values[3].as_primitive::(); + let partial_sum_xx = values[4].as_primitive::(); + let partial_sum_yy = values[5].as_primitive::(); + + assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage"); + + accumulate_correlation_states( + group_indices, + ( + partial_counts, + partial_sum_x, + partial_sum_y, + partial_sum_xy, + partial_sum_xx, + partial_sum_yy, + ), + |group_index, count, values| { + self.count[group_index] += count; + self.sum_x[group_index] += values[0]; + self.sum_y[group_index] += values[1]; + self.sum_xy[group_index] += values[2]; + self.sum_xx[group_index] += values[3]; + self.sum_yy[group_index] += values[4]; + }, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let n = match emit_to { + EmitTo::All => self.count.len(), + EmitTo::First(n) => n, + }; + + let mut values = Vec::with_capacity(n); + let mut nulls = BooleanBufferBuilder::new(n); + + // Notes for `Null` handling: + // - If the `count` state of a group is 0, no valid records are accumulated + // for this group, so the aggregation result is `Null`. + // - Correlation can't be calculated when a group only has 1 record, or when + // the `denominator` state is 0. In these cases, the final aggregation + // result should be `Null` (according to PostgreSQL's behavior). + // + // TODO: Old datafusion implementation returns 0.0 for these invalid cases. + // Update this to match PostgreSQL's behavior. + for i in 0..n { + if self.count[i] < 2 { + // TODO: Evaluate as `Null` (see notes above) + values.push(0.0); + nulls.append(false); + continue; + } + + let count = self.count[i]; + let sum_x = self.sum_x[i]; + let sum_y = self.sum_y[i]; + let sum_xy = self.sum_xy[i]; + let sum_xx = self.sum_xx[i]; + let sum_yy = self.sum_yy[i]; + + let mean_x = sum_x / count as f64; + let mean_y = sum_y / count as f64; + + let numerator = sum_xy - sum_x * mean_y; + let denominator = + ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt(); + + if denominator == 0.0 { + // TODO: Evaluate as `Null` (see notes above) + values.push(0.0); + nulls.append(false); + } else { + values.push(numerator / denominator); + nulls.append(true); + } + } + + Ok(Arc::new(Float64Array::new( + values.into(), + Some(nulls.finish().into()), + ))) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let n = match emit_to { + EmitTo::All => self.count.len(), + EmitTo::First(n) => n, + }; + + Ok(vec![ + Arc::new(UInt64Array::from(self.count[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())), + Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())), + ]) + } + + fn size(&self) -> usize { + size_of_val(&self.count) + + size_of_val(&self.sum_x) + + size_of_val(&self.sum_y) + + size_of_val(&self.sum_xy) + + size_of_val(&self.sum_xx) + + size_of_val(&self.sum_yy) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Float64Array, UInt64Array}; + + #[test] + fn test_accumulate_correlation_states() { + // Test data + let group_indices = vec![0, 1, 0, 1]; + let counts = UInt64Array::from(vec![1, 2, 3, 4]); + let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]); + let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]); + let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]); + let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]); + let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]); + + let mut accumulated = vec![]; + accumulate_correlation_states( + &group_indices, + (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy), + |group_idx, count, values| { + accumulated.push((group_idx, count, values.to_vec())); + }, + ); + + let expected = vec![ + (0, 1, vec![10.0, 1.0, 10.0, 100.0, 1.0]), + (1, 2, vec![20.0, 2.0, 40.0, 400.0, 4.0]), + (0, 3, vec![30.0, 3.0, 90.0, 900.0, 9.0]), + (1, 4, vec![40.0, 4.0, 160.0, 1600.0, 16.0]), + ]; + assert_eq!(accumulated, expected); + + // Test that function panics with null values + let counts = UInt64Array::from(vec![Some(1), None, Some(3), Some(4)]); + let sum_x = Float64Array::from(vec![10.0, 20.0, 30.0, 40.0]); + let sum_y = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]); + let sum_xy = Float64Array::from(vec![10.0, 40.0, 90.0, 160.0]); + let sum_xx = Float64Array::from(vec![100.0, 400.0, 900.0, 1600.0]); + let sum_yy = Float64Array::from(vec![1.0, 4.0, 9.0, 16.0]); + + let result = std::panic::catch_unwind(|| { + accumulate_correlation_states( + &group_indices, + (&counts, &sum_x, &sum_y, &sum_xy, &sum_xx, &sum_yy), + |_, _, _| {}, + ) + }); + assert!(result.is_err()); + } +}