Skip to content

Commit

Permalink
fix(query): join predict use cast_expr_to_non_null_boolean (#16937)
Browse files Browse the repository at this point in the history
* fix(query): join predict use cast_expr_to_non_null_boolean

* update

* update
  • Loading branch information
sundy-li authored Nov 25, 2024
1 parent 682039f commit 04df094
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 58 deletions.
75 changes: 31 additions & 44 deletions src/common/column/src/buffer/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,20 @@ pub struct Buffer<T> {
/// the internal byte buffer.
data: Arc<Bytes<T>>,

/// The offset into the buffer.
offset: usize,
/// Pointer into `data`, valid for reads of `length` elements
///
/// We store a pointer instead of an offset to avoid pointer arithmetic
/// which causes LLVM to fail to vectorise code correctly
ptr: *const T,

// the length of the buffer. Given a region `data` of N bytes, [offset..offset+length] is visible
// to this buffer.
length: usize,
}

// SAFETY: NOTE(review) — rationale inferred from the struct definition, please confirm.
// The raw `ptr` field opts `Buffer` out of the auto `Send`/`Sync` impls. `ptr` is documented as
// pointing into `data`, which this struct keeps alive through its `Arc`, so presumably moving or
// sharing a `Buffer` across threads is no different from moving or sharing the `Arc`-owned data.
// Assumes `Bytes<T>` is itself `Send`/`Sync` under the same bounds on `T` — TODO confirm.
unsafe impl<T: Send> Send for Buffer<T> {}
unsafe impl<T: Sync> Sync for Buffer<T> {}

impl<T: PartialEq> PartialEq for Buffer<T> {
#[inline]
fn eq(&self, other: &Self) -> bool {
Expand Down Expand Up @@ -101,9 +107,10 @@ impl<T> Buffer<T> {
/// Auxiliary method to create a new Buffer
pub(crate) fn from_bytes(bytes: Bytes<T>) -> Self {
let length = bytes.len();
let ptr = bytes.as_ptr();
Buffer {
data: Arc::new(bytes),
offset: 0,
ptr,
length,
}
}
Expand All @@ -130,24 +137,7 @@ impl<T> Buffer<T> {
/// Returns the byte slice stored in this buffer
#[inline]
pub fn as_slice(&self) -> &[T] {
// Safety:
// invariant of this struct `offset + length <= data.len()`
debug_assert!(self.offset + self.length <= self.data.len());
unsafe {
self.data
.get_unchecked(self.offset..self.offset + self.length)
}
}

/// Returns the byte slice stored in this buffer
/// # Safety
/// `index` must be smaller than `len`
#[inline]
pub(super) unsafe fn get_unchecked(&self, index: usize) -> &T {
// Safety:
// invariant of this function
debug_assert!(index < self.length);
unsafe { self.data.get_unchecked(self.offset + index) }
self
}

/// Returns a new [`Buffer`] that is a slice of this buffer starting at `offset`.
Expand Down Expand Up @@ -193,20 +183,20 @@ impl<T> Buffer<T> {
/// The caller must ensure `offset + length <= self.len()`
#[inline]
pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
self.offset += offset;
self.ptr = self.ptr.add(offset);
self.length = length;
}

/// Returns a pointer to the start of this buffer.
#[inline]
pub(crate) fn data_ptr(&self) -> *const T {
self.data.deref().as_ptr()
self.data.as_ptr()
}

/// Returns the offset of this buffer.
#[inline]
pub fn offset(&self) -> usize {
self.offset
unsafe { self.ptr.offset_from(self.data_ptr()) as usize }
}

/// # Safety
Expand Down Expand Up @@ -253,10 +243,11 @@ impl<T> Buffer<T> {
/// * has not been imported from the c data interface (FFI)
#[inline]
pub fn get_mut_slice(&mut self) -> Option<&mut [T]> {
let offset = self.offset();
Arc::get_mut(&mut self.data)
.and_then(|b| b.get_vec())
// Safety: the invariant of this struct
.map(|x| unsafe { x.get_unchecked_mut(self.offset..self.offset + self.length) })
.map(|x| unsafe { x.get_unchecked_mut(offset..offset + self.length) })
}

/// Get the strong count of underlying `Arc` data buffer.
Expand All @@ -269,28 +260,14 @@ impl<T> Buffer<T> {
Arc::weak_count(&self.data)
}

/// Returns its internal representation
#[must_use]
pub fn into_inner(self) -> (Arc<Bytes<T>>, usize, usize) {
let Self {
data,
offset,
length,
} = self;
(data, offset, length)
}

/// Creates a [`Buffer`] from its internal representation.
/// This is the inverse of [`Buffer::into_inner`].
///
/// # Safety
/// Callers must ensure all invariants of this struct are upheld.
pub unsafe fn from_inner_unchecked(data: Arc<Bytes<T>>, offset: usize, length: usize) -> Self {
Self {
data,
offset,
length,
}
let ptr = data.as_ptr().add(offset);
Self { data, ptr, length }
}
}

Expand All @@ -313,8 +290,9 @@ impl<T> From<Vec<T>> for Buffer<T> {
#[inline]
fn from(p: Vec<T>) -> Self {
let bytes: Bytes<T> = p.into();
let ptr = bytes.as_ptr();
Self {
offset: 0,
ptr,
length: bytes.len(),
data: Arc::new(bytes),
}
Expand All @@ -326,7 +304,15 @@ impl<T> std::ops::Deref for Buffer<T> {

/// Borrows the visible window of this buffer as a slice of `length` elements.
#[inline]
fn deref(&self) -> &[T] {
self.as_slice()
// Debug-only check that the window `[offset, offset + length)` lies inside `data`.
debug_assert!(self.offset() + self.length <= self.data.len());
// SAFETY: relies on the struct invariant that `ptr` points into `data` and is valid for
// `length` elements; the assertion above verifies this in debug builds only.
unsafe { std::slice::from_raw_parts(self.ptr, self.length) }
}
}

impl<T> AsRef<[T]> for Buffer<T> {
    /// Exposes the buffer's elements as a plain slice, so a `Buffer<T>` can be handed to any API
    /// that accepts `impl AsRef<[T]>`.
    #[inline]
    fn as_ref(&self) -> &[T] {
        // Delegate explicitly to the slice accessor instead of relying on deref coercion.
        self.as_slice()
    }
}

Expand Down Expand Up @@ -375,8 +361,9 @@ impl<T: crate::types::NativeType> From<arrow_buffer::Buffer> for Buffer<T> {

impl<T: crate::types::NativeType> From<Buffer<T>> for arrow_buffer::Buffer {
fn from(value: Buffer<T>) -> Self {
let offset = value.offset();
crate::buffer::to_buffer(value.data).slice_with_length(
value.offset * std::mem::size_of::<T>(),
offset * std::mem::size_of::<T>(),
value.length * std::mem::size_of::<T>(),
)
}
Expand Down
3 changes: 3 additions & 0 deletions src/common/column/tests/it/buffer/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ fn from_slice() {
let buffer = Buffer::<i32>::from(vec![0, 1, 2]);
assert_eq!(buffer.len(), 3);
assert_eq!(buffer.as_slice(), &[0, 1, 2]);

assert_eq!(unsafe { *buffer.get_unchecked(1) }, 1);
assert_eq!(unsafe { *buffer.get_unchecked(2) }, 2);
}

#[test]
Expand Down
17 changes: 17 additions & 0 deletions src/query/expression/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#[macro_use]
extern crate criterion;

use arrow_buffer::ScalarBuffer;
use criterion::Criterion;
use databend_common_column::buffer::Buffer;
use databend_common_expression::arrow::deserialize_column;
Expand Down Expand Up @@ -135,6 +136,9 @@ fn bench(c: &mut Criterion) {
for length in [10240, 102400] {
let (left, right) = generate_random_int_data(&mut rng, length);

let left_scalar = ScalarBuffer::from_iter(left.iter().cloned());
let right_scalar = ScalarBuffer::from_iter(right.iter().cloned());

group.bench_function(format!("function_iterator_iterator_v1/{length}"), |b| {
b.iter(|| {
let left = left.clone();
Expand Down Expand Up @@ -170,6 +174,19 @@ fn bench(c: &mut Criterion) {
},
);

group.bench_function(
format!("function_buffer_scalar_index_unchecked_iterator/{length}"),
|b| {
b.iter(|| {
let _c = (0..length)
.map(|i| unsafe {
left_scalar.get_unchecked(i) + right_scalar.get_unchecked(i)
})
.collect::<Vec<i32>>();
})
},
);

group.bench_function(
format!("function_slice_index_unchecked_iterator/{length}"),
|b| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use databend_common_expression::type_check::check_function;
use databend_common_expression::Expr;
use databend_common_expression::RemoteExpr;
use databend_common_functions::BUILTIN_FUNCTIONS;
use databend_common_sql::executor::cast_expr_to_non_null_boolean;
use databend_common_sql::executor::physical_plans::HashJoin;
use databend_common_sql::IndexType;
use parking_lot::RwLock;
Expand Down Expand Up @@ -96,11 +97,21 @@ impl HashJoinDesc {
}

/// Folds the non-equi join conditions into a single predicate expression.
///
/// Each `RemoteExpr` is converted to an `Expr` against `BUILTIN_FUNCTIONS`, and the results are
/// combined pairwise through the `and_filters` builtin via `check_function`.
///
/// Returns `Ok(None)` when there is nothing to combine, and forwards any error produced while
/// building the combined expression.
fn join_predicate(non_equi_conditions: &[RemoteExpr]) -> Result<Option<Expr>> {
non_equi_conditions
let expr = non_equi_conditions
.iter()
.map(|expr| expr.as_expr(&BUILTIN_FUNCTIONS))
.try_reduce(|lhs, rhs| {
check_function(None, "and_filters", &[], &[lhs, rhs], &BUILTIN_FUNCTIONS)
})
});
// For RIGHT MARK join, we can't use is_true to cast filter into non_null boolean
// Hence only a non-null *constant* predicate is passed through `cast_expr_to_non_null_boolean`;
// every other expression is returned unchanged.
match expr {
Ok(Some(expr)) => match expr {
Expr::Constant { ref scalar, .. } if !scalar.is_null() => {
Ok(Some(cast_expr_to_non_null_boolean(expr)?))
}
_ => Ok(Some(expr)),
},
// `Ok(None)` (no conditions) and `Err(_)` are forwarded as-is.
other => other,
}
}
}
30 changes: 18 additions & 12 deletions tests/sqllogictests/suites/query/join/join.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ statement ok
drop table if exists t1;

statement ok
create table t1 (a int);
create or replace table t1 (a int);

# right join with empty build side
query II
Expand Down Expand Up @@ -82,7 +82,7 @@ statement ok
drop table if exists t1;

statement ok
create table t1(a int, b int)
create or replace table t1(a int, b int)

statement ok
insert into t1 values(1, 2), (2, 4), (3, 6), (4, 8), (5, 10)
Expand All @@ -91,7 +91,7 @@ statement ok
drop table if exists t2

statement ok
create table t2(a int, b int)
create or replace table t2(a int, b int)

statement ok
insert into t2 values(1, 2), (1, 4), (1, 6), (1, 8), (1, 10);
Expand Down Expand Up @@ -124,10 +124,10 @@ statement ok
drop table if exists t2;

statement ok
create table t1 (id int, val bigint unsigned default 0);
create or replace table t1 (id int, val bigint unsigned default 0);

statement ok
create table t2 (id int, val bigint unsigned default 0);
create or replace table t2 (id int, val bigint unsigned default 0);

statement ok
insert into t1 values(1, 1696549154011), (2, 1696549154013);
Expand All @@ -153,13 +153,13 @@ statement ok
drop table t2;

statement ok
create table t(id int);
create or replace table t(id int);

statement ok
insert into t values(1), (2);

statement ok
create table t1(id int, col1 varchar);
create or replace table t1(id int, col1 varchar);

statement ok
insert into t1 values(1, 'c'), (3, 'd');
Expand Down Expand Up @@ -203,13 +203,13 @@ statement ok
drop table if exists t2;

statement ok
create table t1(a int, b int);
create or replace table t1(a int, b int);

statement ok
insert into t1 values(1, 1),(2, 2),(3, 3);

statement ok
create table t2(a int, b int);
create or replace table t2(a int, b int);

statement ok
insert into t2 values(1, 1),(2, 2);
Expand Down Expand Up @@ -237,13 +237,13 @@ statement ok
drop table if exists t2;

statement ok
create table t1(a int, b int, c int, d int);
create or replace table t1(a int, b int, c int, d int);

statement ok
insert into t1 values(1, 2, 3, 4);

statement ok
create table t2(a int, b int, c int, d int);
create or replace table t2(a int, b int, c int, d int);

statement ok
insert into t2 values(1, 2, 3, 4);
Expand All @@ -255,7 +255,7 @@ statement ok
drop table if exists t;

statement ok
create table t(a int);
create or replace table t(a int);

statement ok
insert into t values(1),(2),(3);
Expand All @@ -274,5 +274,11 @@ select * from (select number from numbers(5)) t2 full outer join (select a, 'A'
2 2 A
3 3 A

statement ok
select * from (select number from numbers(5)) t2 full outer join (select a, 'A' as name from t) t1 on t1.a = t2.number and 123;

statement error
select * from (select number from numbers(5)) t2 full outer join (select a, 'A' as name from t) t1 on t1.a = t2.number and 11981933213501947393::DATE;

statement ok
drop table if exists t;

0 comments on commit 04df094

Please sign in to comment.