Skip to content

Commit

Permalink
Tidy up join test code (#13604)
Browse files Browse the repository at this point in the history
  • Loading branch information
ozankabak authored Nov 29, 2024
1 parent 189a4bb commit 55a0040
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 31 deletions.
44 changes: 22 additions & 22 deletions datafusion/core/src/physical_optimizer/join_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ fn supports_swap(join_type: JoinType) -> bool {

/// This function returns the new join type we get after swapping the given
/// join's inputs.
fn swap_join_type(join_type: JoinType) -> JoinType {
pub(crate) fn swap_join_type(join_type: JoinType) -> JoinType {
match join_type {
JoinType::Inner => JoinType::Inner,
JoinType::Full => JoinType::Full,
Expand Down Expand Up @@ -256,7 +256,7 @@ fn swap_nl_join(join: &NestedLoopJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
/// the output should not be impacted. This function creates the expressions
/// that will allow to swap back the values from the original left as the first
/// columns and those on the right next.
fn swap_reverting_projection(
pub(crate) fn swap_reverting_projection(
left_schema: &Schema,
right_schema: &Schema,
) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
Expand All @@ -278,7 +278,7 @@ fn swap_reverting_projection(
}

/// Swaps join sides for filter column indices and produces new JoinFilter
fn swap_filter(filter: &JoinFilter) -> JoinFilter {
pub(crate) fn swap_filter(filter: &JoinFilter) -> JoinFilter {
let column_indices = filter
.column_indices()
.iter()
Expand Down Expand Up @@ -736,6 +736,7 @@ mod tests_statistical {
use arrow::datatypes::{DataType, Field};
use datafusion_common::{stats::Precision, JoinType, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::expressions::BinaryExpr;
use datafusion_physical_expr::PhysicalExprRef;

Expand Down Expand Up @@ -1089,8 +1090,8 @@ mod tests_statistical {
Arc::clone(&big),
Arc::clone(&small),
vec![(
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()),
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()),
col("big_col", &big.schema()).unwrap(),
col("small_col", &small.schema()).unwrap(),
)],
None,
&JoinType::Inner,
Expand All @@ -1106,10 +1107,8 @@ mod tests_statistical {
Arc::clone(&medium),
Arc::new(child_join),
vec![(
Arc::new(
Column::new_with_schema("medium_col", &medium.schema()).unwrap(),
),
Arc::new(Column::new_with_schema("small_col", &child_schema).unwrap()),
col("medium_col", &medium.schema()).unwrap(),
col("small_col", &child_schema).unwrap(),
)],
None,
&JoinType::Left,
Expand Down Expand Up @@ -1421,8 +1420,8 @@ mod tests_statistical {
));

let join_on = vec![(
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
col("small_col", &small.schema()).unwrap(),
col("big_col", &big.schema()).unwrap(),
)];
check_join_partition_mode(
small.clone(),
Expand All @@ -1433,8 +1432,8 @@ mod tests_statistical {
);

let join_on = vec![(
Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
col("big_col", &big.schema()).unwrap(),
col("small_col", &small.schema()).unwrap(),
)];
check_join_partition_mode(
big,
Expand All @@ -1445,8 +1444,8 @@ mod tests_statistical {
);

let join_on = vec![(
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
col("small_col", &small.schema()).unwrap(),
col("empty_col", &empty.schema()).unwrap(),
)];
check_join_partition_mode(
small.clone(),
Expand All @@ -1457,8 +1456,8 @@ mod tests_statistical {
);

let join_on = vec![(
Arc::new(Column::new_with_schema("empty_col", &empty.schema()).unwrap()) as _,
Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _,
col("empty_col", &empty.schema()).unwrap(),
col("small_col", &small.schema()).unwrap(),
)];
check_join_partition_mode(
empty,
Expand Down Expand Up @@ -1627,6 +1626,7 @@ mod hash_join_tests {

use arrow::datatypes::{DataType, Field};
use arrow::record_batch::RecordBatch;
use datafusion_physical_expr::expressions::col;

struct TestCase {
case: String,
Expand Down Expand Up @@ -1969,7 +1969,7 @@ mod hash_join_tests {
false,
)]))),
2,
)) as Arc<dyn ExecutionPlan>;
)) as _;
let right_exec = Arc::new(UnboundedExec::new(
(!right_unbounded).then_some(1),
RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new(
Expand All @@ -1978,21 +1978,21 @@ mod hash_join_tests {
false,
)]))),
2,
)) as Arc<dyn ExecutionPlan>;
)) as _;

let join = Arc::new(HashJoinExec::try_new(
Arc::clone(&left_exec),
Arc::clone(&right_exec),
vec![(
Arc::new(Column::new_with_schema("a", &left_exec.schema())?),
Arc::new(Column::new_with_schema("b", &right_exec.schema())?),
col("a", &left_exec.schema())?,
col("b", &right_exec.schema())?,
)],
None,
&t.initial_join_type,
None,
t.initial_mode,
false,
)?);
)?) as _;

let optimized_join_plan = hash_join_swap_subrule(join, &ConfigOptions::new())?;

Expand Down
20 changes: 11 additions & 9 deletions datafusion/physical-plan/src/joins/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,23 @@ use rand::prelude::StdRng;
use rand::{Rng, SeedableRng};

pub fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) {
let left_row_num: usize = collected_1.iter().map(|batch| batch.num_rows()).sum();
let right_row_num: usize = collected_2.iter().map(|batch| batch.num_rows()).sum();
if left_row_num == 0 && right_row_num == 0 {
return;
}
// compare
let first_formatted = pretty_format_batches(collected_1).unwrap().to_string();
let second_formatted = pretty_format_batches(collected_2).unwrap().to_string();

let mut first_formatted_sorted: Vec<&str> = first_formatted.trim().lines().collect();
first_formatted_sorted.sort_unstable();
let mut first_lines: Vec<&str> = first_formatted.trim().lines().collect();
first_lines.sort_unstable();

let mut second_formatted_sorted: Vec<&str> =
second_formatted.trim().lines().collect();
second_formatted_sorted.sort_unstable();
let mut second_lines: Vec<&str> = second_formatted.trim().lines().collect();
second_lines.sort_unstable();

for (i, (first_line, second_line)) in first_formatted_sorted
.iter()
.zip(&second_formatted_sorted)
.enumerate()
for (i, (first_line, second_line)) in
first_lines.iter().zip(&second_lines).enumerate()
{
assert_eq!((i, first_line), (i, second_line));
}
Expand Down

0 comments on commit 55a0040

Please sign in to comment.