Skip to content

Commit

Permalink
Implement GroupsAccumulator for corr(x,y)
Browse files Browse the repository at this point in the history
  • Loading branch information
2010YOUY01 committed Nov 27, 2024
1 parent 2e05648 commit a834fda
Show file tree
Hide file tree
Showing 2 changed files with 513 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,87 @@ pub fn accumulate<T, F>(
}
}

/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`)
///
/// This method assumes that for any input record index, if any of the value column
/// is null, or it's filtered out by `opt_filter`, then the record would be ignored.
/// (won't be accumulated by `value_fn`)
pub fn accumulate_multiple<T, F>(
group_indices: &[usize],
value_columns: &[&PrimitiveArray<T>],
opt_filter: Option<&BooleanArray>,
mut value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, &[T::Native]) + Send,
{
let acc_cols: Vec<&[T::Native]> = value_columns
.iter()
.map(|arr| arr.values().as_ref())
.collect();

// Calculate `valid_indices` to accumulate, non-valid indices are ignored.
// `valid_indices` is a bit mask corresponding to the `group_indices`. An index
// is considered valid if:
// 1. All columns are non-null at this index.
// 2. Not filtered out by `opt_filter`

// Take AND from all null buffers of `value_columns`.
let mut combined_nulls: Option<NullBuffer> = None;

for arr in value_columns.iter() {
if arr.null_count() > 0 {
let nulls = arr
.nulls()
.expect("If null_count() > 0, nulls must be present");
match combined_nulls {
None => {
combined_nulls = Some(nulls.clone());
}
Some(ref mut combined) => {
let result = NullBuffer::union(Some(combined), Some(nulls)).unwrap();
*combined = result.clone();
}
}
}
}

// Take AND from previous combined nulls and `opt_filter`.
let valid_indices = match (combined_nulls, opt_filter) {
(None, None) => None,
(None, Some(filter)) => Some(filter.clone()),
(Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)),
(Some(nulls), Some(filter)) => {
let combined = nulls.inner() & filter.values();
Some(BooleanArray::new(combined, None))
}
};

for col in acc_cols.iter() {
assert_eq!(col.len(), group_indices.len());
}

match valid_indices {
None => {
for (idx, &group_idx) in group_indices.iter().enumerate() {
// Get `idx`-th row from all value(accumulate) columns
let row_values: Vec<_> = acc_cols.iter().map(|col| col[idx]).collect();
value_fn(group_idx, &row_values);
}
}
Some(valid_indices) => {
for (idx, &group_idx) in group_indices.iter().enumerate() {
if valid_indices.value(idx) {
// Get `idx`-th row from all value(accumulate) columns
let row_values: Vec<_> =
acc_cols.iter().map(|col| col[idx]).collect();
value_fn(group_idx, &row_values);
}
}
}
}
}

/// This function is called to update the accumulator state per row
/// when the value is not needed (e.g. COUNT)
///
Expand Down Expand Up @@ -528,7 +609,7 @@ fn initialize_builder(
mod test {
use super::*;

use arrow::array::UInt32Array;
use arrow::array::{Int32Array, UInt32Array};
use rand::{rngs::ThreadRng, Rng};
use std::collections::HashSet;

Expand Down Expand Up @@ -940,4 +1021,103 @@ mod test {
.collect()
}
}

#[test]
fn test_accumulate_multiple_no_nulls_no_filter() {
let group_indices = vec![0, 1, 0, 1];
let values1 = Int32Array::from(vec![1, 2, 3, 4]);
let values2 = Int32Array::from(vec![10, 20, 30, 40]);
let value_columns = vec![values1, values2];

let mut accumulated = vec![];
accumulate_multiple(
&group_indices,
&value_columns.iter().collect::<Vec<_>>(),
None,
|group_idx, values| {
accumulated.push((group_idx, values.to_vec()));
},
);

let expected = vec![
(0, vec![1, 10]),
(1, vec![2, 20]),
(0, vec![3, 30]),
(1, vec![4, 40]),
];
assert_eq!(accumulated, expected);
}

#[test]
fn test_accumulate_multiple_with_nulls() {
let group_indices = vec![0, 1, 0, 1];
let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
let value_columns = vec![values1, values2];

let mut accumulated = vec![];
accumulate_multiple(
&group_indices,
&value_columns.iter().collect::<Vec<_>>(),
None,
|group_idx, values| {
accumulated.push((group_idx, values.to_vec()));
},
);

// Only rows where both columns are non-null should be accumulated
let expected = vec![(0, vec![1, 10]), (1, vec![4, 40])];
assert_eq!(accumulated, expected);
}

#[test]
fn test_accumulate_multiple_with_filter() {
let group_indices = vec![0, 1, 0, 1];
let values1 = Int32Array::from(vec![1, 2, 3, 4]);
let values2 = Int32Array::from(vec![10, 20, 30, 40]);
let value_columns = vec![values1, values2];

let filter = BooleanArray::from(vec![true, false, true, false]);

let mut accumulated = vec![];
accumulate_multiple(
&group_indices,
&value_columns.iter().collect::<Vec<_>>(),
Some(&filter),
|group_idx, values| {
accumulated.push((group_idx, values.to_vec()));
},
);

// Only rows where filter is true should be accumulated
let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])];
assert_eq!(accumulated, expected);
}

#[test]
fn test_accumulate_multiple_with_nulls_and_filter() {
let group_indices = vec![0, 1, 0, 1];
let values1 = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
let values2 = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
let value_columns = vec![values1, values2];

let filter = BooleanArray::from(vec![true, true, true, false]);

let mut accumulated = vec![];
accumulate_multiple(
&group_indices,
&value_columns.iter().collect::<Vec<_>>(),
Some(&filter),
|group_idx, values| {
accumulated.push((group_idx, values.to_vec()));
},
);

// Only rows where both:
// 1. Filter is true
// 2. Both columns are non-null
// should be accumulated
let expected = vec![(0, vec![1, 10])];
assert_eq!(accumulated, expected);
}
}
Loading

0 comments on commit a834fda

Please sign in to comment.