Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(query): join predict use cast_expr_to_non_null_boolean #16937

Merged
merged 3 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 31 additions & 44 deletions src/common/column/src/buffer/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,20 @@ pub struct Buffer<T> {
/// the internal byte buffer.
data: Arc<Bytes<T>>,

/// The offset into the buffer.
offset: usize,
/// Pointer into `data` valid
///
/// We store a pointer instead of an offset to avoid pointer arithmetic
/// which causes LLVM to fail to vectorise code correctly
ptr: *const T,

// the length of the buffer. Given a region `data` of N bytes, [offset..offset+length] is visible
// to this buffer.
length: usize,
}

unsafe impl<T: Send> Send for Buffer<T> {}
unsafe impl<T: Sync> Sync for Buffer<T> {}

impl<T: PartialEq> PartialEq for Buffer<T> {
#[inline]
fn eq(&self, other: &Self) -> bool {
Expand Down Expand Up @@ -101,9 +107,10 @@ impl<T> Buffer<T> {
/// Auxiliary method to create a new Buffer
pub(crate) fn from_bytes(bytes: Bytes<T>) -> Self {
let length = bytes.len();
let ptr = bytes.as_ptr();
Buffer {
data: Arc::new(bytes),
offset: 0,
ptr,
length,
}
}
Expand All @@ -130,24 +137,7 @@ impl<T> Buffer<T> {
/// Returns the byte slice stored in this buffer
#[inline]
pub fn as_slice(&self) -> &[T] {
// Safety:
// invariant of this struct `offset + length <= data.len()`
debug_assert!(self.offset + self.length <= self.data.len());
unsafe {
self.data
.get_unchecked(self.offset..self.offset + self.length)
}
}

/// Returns the byte slice stored in this buffer
/// # Safety
/// `index` must be smaller than `len`
#[inline]
pub(super) unsafe fn get_unchecked(&self, index: usize) -> &T {
// Safety:
// invariant of this function
debug_assert!(index < self.length);
unsafe { self.data.get_unchecked(self.offset + index) }
self
}

/// Returns a new [`Buffer`] that is a slice of this buffer starting at `offset`.
Expand Down Expand Up @@ -193,20 +183,20 @@ impl<T> Buffer<T> {
/// The caller must ensure `offset + length <= self.len()`
#[inline]
pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
self.offset += offset;
self.ptr = self.ptr.add(offset);
self.length = length;
}

/// Returns a pointer to the start of this buffer.
#[inline]
pub(crate) fn data_ptr(&self) -> *const T {
self.data.deref().as_ptr()
self.data.as_ptr()
}

/// Returns the offset of this buffer.
#[inline]
pub fn offset(&self) -> usize {
self.offset
unsafe { self.ptr.offset_from(self.data_ptr()) as usize }
}

/// # Safety
Expand Down Expand Up @@ -253,10 +243,11 @@ impl<T> Buffer<T> {
/// * has not been imported from the c data interface (FFI)
#[inline]
pub fn get_mut_slice(&mut self) -> Option<&mut [T]> {
let offset = self.offset();
Arc::get_mut(&mut self.data)
.and_then(|b| b.get_vec())
// Safety: the invariant of this struct
.map(|x| unsafe { x.get_unchecked_mut(self.offset..self.offset + self.length) })
.map(|x| unsafe { x.get_unchecked_mut(offset..offset + self.length) })
}

/// Get the strong count of underlying `Arc` data buffer.
Expand All @@ -269,28 +260,14 @@ impl<T> Buffer<T> {
Arc::weak_count(&self.data)
}

/// Returns its internal representation
#[must_use]
pub fn into_inner(self) -> (Arc<Bytes<T>>, usize, usize) {
let Self {
data,
offset,
length,
} = self;
(data, offset, length)
}

/// Creates a `[Bitmap]` from its internal representation.
/// This is the inverted from `[Bitmap::into_inner]`
///
/// # Safety
/// Callers must ensure all invariants of this struct are upheld.
pub unsafe fn from_inner_unchecked(data: Arc<Bytes<T>>, offset: usize, length: usize) -> Self {
Self {
data,
offset,
length,
}
let ptr = data.as_ptr().add(offset);
Self { data, ptr, length }
}
}

Expand All @@ -313,8 +290,9 @@ impl<T> From<Vec<T>> for Buffer<T> {
#[inline]
fn from(p: Vec<T>) -> Self {
let bytes: Bytes<T> = p.into();
let ptr = bytes.as_ptr();
Self {
offset: 0,
ptr,
length: bytes.len(),
data: Arc::new(bytes),
}
Expand All @@ -326,7 +304,15 @@ impl<T> std::ops::Deref for Buffer<T> {

#[inline]
fn deref(&self) -> &[T] {
self.as_slice()
debug_assert!(self.offset() + self.length <= self.data.len());
unsafe { std::slice::from_raw_parts(self.ptr, self.length) }
}
}

impl<T> AsRef<[T]> for Buffer<T> {
#[inline]
fn as_ref(&self) -> &[T] {
self
}
}

Expand Down Expand Up @@ -375,8 +361,9 @@ impl<T: crate::types::NativeType> From<arrow_buffer::Buffer> for Buffer<T> {

impl<T: crate::types::NativeType> From<Buffer<T>> for arrow_buffer::Buffer {
fn from(value: Buffer<T>) -> Self {
let offset = value.offset();
crate::buffer::to_buffer(value.data).slice_with_length(
value.offset * std::mem::size_of::<T>(),
offset * std::mem::size_of::<T>(),
value.length * std::mem::size_of::<T>(),
)
}
Expand Down
3 changes: 3 additions & 0 deletions src/common/column/tests/it/buffer/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ fn from_slice() {
let buffer = Buffer::<i32>::from(vec![0, 1, 2]);
assert_eq!(buffer.len(), 3);
assert_eq!(buffer.as_slice(), &[0, 1, 2]);

assert_eq!(unsafe { *buffer.get_unchecked(1) }, 1);
assert_eq!(unsafe { *buffer.get_unchecked(2) }, 2);
}

#[test]
Expand Down
17 changes: 17 additions & 0 deletions src/query/expression/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#[macro_use]
extern crate criterion;

use arrow_buffer::ScalarBuffer;
use criterion::Criterion;
use databend_common_column::buffer::Buffer;
use databend_common_expression::arrow::deserialize_column;
Expand Down Expand Up @@ -135,6 +136,9 @@ fn bench(c: &mut Criterion) {
for length in [10240, 102400] {
let (left, right) = generate_random_int_data(&mut rng, length);

let left_scalar = ScalarBuffer::from_iter(left.iter().cloned());
let right_scalar = ScalarBuffer::from_iter(right.iter().cloned());

group.bench_function(format!("function_iterator_iterator_v1/{length}"), |b| {
b.iter(|| {
let left = left.clone();
Expand Down Expand Up @@ -170,6 +174,19 @@ fn bench(c: &mut Criterion) {
},
);

group.bench_function(
format!("function_buffer_scalar_index_unchecked_iterator/{length}"),
|b| {
b.iter(|| {
let _c = (0..length)
.map(|i| unsafe {
left_scalar.get_unchecked(i) + right_scalar.get_unchecked(i)
})
.collect::<Vec<i32>>();
})
},
);

group.bench_function(
format!("function_slice_index_unchecked_iterator/{length}"),
|b| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use databend_common_expression::type_check::check_function;
use databend_common_expression::Expr;
use databend_common_expression::RemoteExpr;
use databend_common_functions::BUILTIN_FUNCTIONS;
use databend_common_sql::executor::cast_expr_to_non_null_boolean;
use databend_common_sql::executor::physical_plans::HashJoin;
use databend_common_sql::IndexType;
use parking_lot::RwLock;
Expand Down Expand Up @@ -96,11 +97,21 @@ impl HashJoinDesc {
}

fn join_predicate(non_equi_conditions: &[RemoteExpr]) -> Result<Option<Expr>> {
non_equi_conditions
let expr = non_equi_conditions
.iter()
.map(|expr| expr.as_expr(&BUILTIN_FUNCTIONS))
.try_reduce(|lhs, rhs| {
check_function(None, "and_filters", &[], &[lhs, rhs], &BUILTIN_FUNCTIONS)
})
});
// For RIGHT MARK join, we can't use is_true to cast filter into non_null boolean
match expr {
Ok(Some(expr)) => match expr {
Expr::Constant { ref scalar, .. } if !scalar.is_null() => {
Ok(Some(cast_expr_to_non_null_boolean(expr)?))
}
_ => Ok(Some(expr)),
},
other => other,
}
}
}
30 changes: 18 additions & 12 deletions tests/sqllogictests/suites/query/join/join.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ statement ok
drop table if exists t1;

statement ok
create table t1 (a int);
create or replace table t1 (a int);

# right join with empty build side
query II
Expand Down Expand Up @@ -82,7 +82,7 @@ statement ok
drop table if exists t1;

statement ok
create table t1(a int, b int)
create or replace table t1(a int, b int)

statement ok
insert into t1 values(1, 2), (2, 4), (3, 6), (4, 8), (5, 10)
Expand All @@ -91,7 +91,7 @@ statement ok
drop table if exists t2

statement ok
create table t2(a int, b int)
create or replace table t2(a int, b int)

statement ok
insert into t2 values(1, 2), (1, 4), (1, 6), (1, 8), (1, 10);
Expand Down Expand Up @@ -124,10 +124,10 @@ statement ok
drop table if exists t2;

statement ok
create table t1 (id int, val bigint unsigned default 0);
create or replace table t1 (id int, val bigint unsigned default 0);

statement ok
create table t2 (id int, val bigint unsigned default 0);
create or replace table t2 (id int, val bigint unsigned default 0);

statement ok
insert into t1 values(1, 1696549154011), (2, 1696549154013);
Expand All @@ -153,13 +153,13 @@ statement ok
drop table t2;

statement ok
create table t(id int);
create or replace table t(id int);

statement ok
insert into t values(1), (2);

statement ok
create table t1(id int, col1 varchar);
create or replace table t1(id int, col1 varchar);

statement ok
insert into t1 values(1, 'c'), (3, 'd');
Expand Down Expand Up @@ -203,13 +203,13 @@ statement ok
drop table if exists t2;

statement ok
create table t1(a int, b int);
create or replace table t1(a int, b int);

statement ok
insert into t1 values(1, 1),(2, 2),(3, 3);

statement ok
create table t2(a int, b int);
create or replace table t2(a int, b int);

statement ok
insert into t2 values(1, 1),(2, 2);
Expand Down Expand Up @@ -237,13 +237,13 @@ statement ok
drop table if exists t2;

statement ok
create table t1(a int, b int, c int, d int);
create or replace table t1(a int, b int, c int, d int);

statement ok
insert into t1 values(1, 2, 3, 4);

statement ok
create table t2(a int, b int, c int, d int);
create or replace table t2(a int, b int, c int, d int);

statement ok
insert into t2 values(1, 2, 3, 4);
Expand All @@ -255,7 +255,7 @@ statement ok
drop table if exists t;

statement ok
create table t(a int);
create or replace table t(a int);

statement ok
insert into t values(1),(2),(3);
Expand All @@ -274,5 +274,11 @@ select * from (select number from numbers(5)) t2 full outer join (select a, 'A'
2 2 A
3 3 A

statement ok
select * from (select number from numbers(5)) t2 full outer join (select a, 'A' as name from t) t1 on t1.a = t2.number and 123;

statement error
select * from (select number from numbers(5)) t2 full outer join (select a, 'A' as name from t) t1 on t1.a = t2.number and 11981933213501947393::DATE;

statement ok
drop table if exists t;
Loading