From be70e939333f2a2ec2cc9a3fd1123f58fe596d8e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 2 Dec 2024 14:46:39 -0500 Subject: [PATCH] update API --- .../functions-aggregate/src/array_agg.rs | 9 +++---- .../functions-aggregate/src/nth_value.rs | 6 ++--- datafusion/functions-nested/src/make_array.rs | 7 +++-- datafusion/functions-nested/src/utils.rs | 27 ++++++++++--------- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 3b9a521ec972..39f8c88338b1 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -22,7 +22,7 @@ use arrow::datatypes::DataType; use arrow_schema::{Field, Fields}; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; +use datafusion_common::utils::{get_row_at_idx, SingleRowArrayBuilder}; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{internal_err, Result}; use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; @@ -238,9 +238,8 @@ impl Accumulator for ArrayAggAccumulator { } let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array_nullable(concated_array); - Ok(ScalarValue::List(Arc::new(list_array))) + Ok(SingleRowArrayBuilder::new(concated_array).build_list_scalar()) } fn size(&self) -> usize { @@ -530,9 +529,7 @@ impl OrderSensitiveArrayAggAccumulator { let ordering_array = StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( - Arc::new(ordering_array), - )))) + Ok(SingleRowArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) } } diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 0c72939633b1..a882a3c8e9d1 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -26,7 +26,7 @@ use std::sync::{Arc, OnceLock}; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; use arrow_schema::{DataType, Field, Fields}; -use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; +use datafusion_common::utils::{get_row_at_idx, SingleRowArrayBuilder}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -423,9 +423,7 @@ impl NthValueAccumulator { let ordering_array = StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( - Arc::new(ordering_array), - )))) + Ok(SingleRowArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) } fn evaluate_values(&self) -> ScalarValue { diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 825824a82d20..1027a20c3898 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -28,7 +28,8 @@ use arrow_array::{ use arrow_buffer::OffsetBuffer; use arrow_schema::DataType::{List, Null}; use arrow_schema::{DataType, Field}; -use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; +use datafusion_common::utils::SingleRowArrayBuilder; +use datafusion_common::{plan_err, Result}; use datafusion_expr::binary::{ try_type_union_resolution_with_struct, type_union_resolution, }; @@ -194,7 +195,9 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { let length = arrays.iter().map(|a| a.len()).sum(); // By default Int64 let array = new_null_array(&DataType::Int64, length); - Ok(Arc::new(array_into_list_array_nullable(array))) + Ok(Arc::new( + SingleRowArrayBuilder::new(array).build_list_array(), + )) } _ => array_array::(arrays, data_type), } diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index b9a75724bcde..8381ac7a01e5 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -274,27 +274,27 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { mod tests { use super::*; use arrow::datatypes::Int64Type; - use datafusion_common::utils::array_into_list_array_nullable; + use datafusion_common::utils::SingleRowArrayBuilder; /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt` #[test] fn test_align_array_dimensions() { - let array1d_1 = + let array1d_1: ArrayRef = Arc::new(ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), Some(vec![Some(4), Some(5)]), ])); - let array1d_2 = + let array1d_2: ArrayRef = Arc::new(ListArray::from_iter_primitive::(vec![ Some(vec![Some(6), Some(7), Some(8)]), ])); - let array2d_1 = Arc::new(array_into_list_array_nullable( - Arc::clone(&array1d_1) as ArrayRef - )) as ArrayRef; - let array2d_2 = Arc::new(array_into_list_array_nullable( - Arc::clone(&array1d_2) as ArrayRef - )) as ArrayRef; + let array2d_1: ArrayRef = Arc::new( + SingleRowArrayBuilder::new(Arc::clone(&array1d_1)).build_list_array(), + ); + let array2d_2 = Arc::new( + SingleRowArrayBuilder::new(Arc::clone(&array1d_2)).build_list_array(), + ); let res = align_array_dimensions::(vec![ array1d_1.to_owned(), @@ -310,10 +310,11 @@ mod tests { expected_dim ); - let array3d_1 = Arc::new(array_into_list_array_nullable(array2d_1)) as ArrayRef; - let array3d_2 = array_into_list_array_nullable(array2d_2.to_owned()); - let res = - align_array_dimensions::(vec![array1d_1, Arc::new(array3d_2)]).unwrap(); + let array3d_1: ArrayRef = + Arc::new(SingleRowArrayBuilder::new(array2d_1).build_list_array()); + let array3d_2: ArrayRef = + Arc::new(SingleRowArrayBuilder::new(array2d_2).build_list_array()); + let res = align_array_dimensions::(vec![array1d_1, array3d_2]).unwrap(); let expected = as_list_array(&array3d_1).unwrap(); let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type());