Merge pull request #2046 from mabel-dev/#2045
joocer authored Oct 2, 2024
2 parents ac7b40c + e21332f · commit 84e3e9c
Showing 7 changed files with 184 additions and 47 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
@@ -1,4 +1,4 @@
-__build__ = 814
+__build__ = 816

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
194 changes: 159 additions & 35 deletions opteryx/compiled/structures/hash_table.pyx
@@ -24,6 +24,7 @@ cdef class HashTable:

    def __init__(self):
        self.hash_table = unordered_map[int64_t, vector[int64_t]]()
+       self.hash_table.reserve(1_048_576)  # try to prevent needing to resize

    cpdef bint insert(self, int64_t key, int64_t row_id):
        # If the key is already in the hash table, append the row_id to the existing list.
@@ -57,16 +58,16 @@ cdef class HashSet:
    cdef inline bint contains(self, int64_t value):
        return self.c_set.find(value) != self.c_set.end()

@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline object recast_column(column):
    cdef column_type = column.type

    if pyarrow.types.is_struct(column_type) or pyarrow.types.is_list(column_type):
-       return numpy.array([str(a) for a in column.to_pylist()], dtype=numpy.str_)
+       return numpy.array([str(a) for a in column], dtype=numpy.str_)

    # Otherwise, return the column as-is
    return column

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef tuple distinct(table, HashSet seen_hashes=None, list columns=None):
    """
@@ -109,16 +110,27 @@ cpdef tuple distinct(table, HashSet seen_hashes=None, list columns=None):

    return (keep, seen_hashes)

@cython.boundscheck(False)
@cython.wraparound(False)
-cpdef tuple list_distinct(list values, cnp.ndarray indices, HashSet seen_hashes=None):
-    new_indices = []
-    new_values = []
-    for i, v in enumerate(values):
-        if seen_hashes.insert(hash(v)):
-            new_values.append(v)
-            new_indices.append(indices[i])
-    return new_values, new_indices, seen_hashes
+cpdef tuple list_distinct(cnp.ndarray values, cnp.ndarray indices, HashSet seen_hashes=None):
+    cdef cnp.ndarray[object] new_values = numpy.empty(values.shape[0], dtype=object)
+    cdef cnp.ndarray[int64_t] new_indices = numpy.empty(indices.shape[0], dtype=numpy.int64)
+
+    cdef int i, j = 0
+    cdef object v
+    cdef int64_t hash_value
+
+    if seen_hashes is None:
+        seen_hashes = HashSet()
+
+    for i in range(values.shape[0]):
+        v = values[i]
+        hash_value = <int64_t>hash(v)
+        if seen_hashes.insert(hash_value):
+            new_values[j] = v
+            new_indices[j] = indices[i]
+            j += 1
+    return new_values[:j], new_indices[:j], seen_hashes

@cython.boundscheck(False)
@cython.wraparound(False)
@@ -135,51 +147,163 @@ cpdef HashTable hash_join_map(relation, list join_columns):
    values are lists of row indices corresponding to each hash key.
    """
    cdef HashTable ht = HashTable()
-   cdef cnp.ndarray[uint8_t, ndim=1] bitmap_array

-   # Selecting columns
+   # Get the dimensions of the dataset we're working with
    cdef int64_t num_rows = relation.num_rows
    cdef int64_t num_columns = len(join_columns)

-   # Allocate memory for the combined nulls array
-   cdef cnp.ndarray[uint8_t, ndim=1] combined_nulls = numpy.full(num_rows, 1, dtype=numpy.uint8)
+   # Memory view for combined nulls (used to check for nulls in any column)
+   cdef uint8_t[:,] combined_nulls = numpy.full(num_rows, 1, dtype=numpy.uint8)

    # Process each column to update the combined null bitmap
-   cdef int64_t i, col_index
    cdef str column_name
    cdef object column, bitmap_buffer
+   cdef int64_t i
+   cdef uint8_t bit, byte
+   cdef uint8_t[::1] bitmap_array

    for column_name in join_columns:
        column = relation.column(column_name)

        if column.null_count > 0:
-           # Get the null bitmap for the current column, ensure it's in a single chunk first
+           # Get the null bitmap for the current column
            bitmap_buffer = column.combine_chunks().buffers()[0]

            if bitmap_buffer is not None:
-               # Convert the bitmap to uint8
+               # Memory view for the bitmap array
                bitmap_array = numpy.frombuffer(bitmap_buffer, dtype=numpy.uint8)

                # Apply bitwise operations on the bitmap
                for i in range(num_rows):
                    byte = bitmap_array[i // 8]
-                   bit = 1 if byte & (1 << (i % 8)) else 0
+                   bit = (byte >> (i % 8)) & 1
                    combined_nulls[i] &= bit

-   # Determine row indices that have nulls in any of the considered columns
+   # Get non-null indices using memory views
    cdef cnp.ndarray non_null_indices = numpy.nonzero(combined_nulls)[0]

-   # Convert selected columns to a numpy array of object dtype, skipping null cols
-   cdef cnp.ndarray values_array = numpy.array([relation.column(column).take(non_null_indices) for column in join_columns], dtype=object)
+   # Memory view for the values array (for the join columns)
+   cdef object[:, ::1] values_array = numpy.array(list(relation.take(non_null_indices).select(join_columns).itercolumns()), dtype=object)

    cdef int64_t hash_value
    cdef tuple value_tuple

-   if num_columns > 1:
-       for i in range(values_array.shape[1]):
-           # Create a tuple of values across the columns for the current row
-           value_tuple = tuple(values_array[:, i])
-           hash_value = <int64_t>hash(value_tuple)
-           ht.insert(hash_value, non_null_indices[i])
-   else:
-       for i, value in enumerate(values_array[0]):
-           hash_value = <int64_t>hash(value)
-           ht.insert(hash_value, non_null_indices[i])
-
-   return ht
+   if num_columns == 1:
+       col = values_array[0, :]
+       for i in range(len(col)):
+           hash_value = <int64_t>hash(col[i])
+           ht.insert(hash_value, non_null_indices[i])
+   else:
+       for i in range(values_array.shape[1]):
+           # Combine the hashes of each value in the row
+           hash_value = 0
+           for value in values_array[:, i]:
+               hash_value = <int64_t>(hash_value * 31 + hash(value))
+           ht.insert(hash_value, non_null_indices[i])
+
+   return ht


"""
Below here is an incomplete attempt at rewriting the hash table builder to be faster.
The key points to make it faster are:
- specialized hashes for different column types
- more C native structures, relying on less Python
This is competitive but doesn't outright beat the above version and currently doesn't pass all of the tests
"""


import cython
import numpy as np
import pyarrow
from libc.stdint cimport int64_t
from libc.stdlib cimport malloc, free

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef HashTable _hash_join_map(relation, list join_columns):
"""
Build a hash table for join operations using column-based hashing.
Each column is hashed separately, and the results are combined efficiently.
Parameters:
relation: The pyarrow.Table to preprocess.
join_columns: A list of column names to join on.
Returns:
A HashTable where keys are combined hashes of the join column entries and
values are lists of row indices corresponding to each hash key.
"""
cdef HashTable ht = HashTable()
cdef int64_t num_rows = relation.num_rows
cdef int64_t num_columns = len(join_columns)

# Create an array to store column hashes
cdef int64_t* cell_hashes = <int64_t*>malloc(num_rows * num_columns * sizeof(int64_t))
if cell_hashes is NULL:
raise Exception("Unable to allocate memory")

# Allocate memory for the combined nulls array
cdef cnp.ndarray[uint8_t, ndim=1] combined_nulls = numpy.full(num_rows, 1, dtype=numpy.uint8)

# Process each column to update the combined null bitmap
cdef int64_t i, j, combined_hash
cdef object column, bitmap_buffer
cdef uint8_t bit, byte

for column_name in join_columns:
column = relation.column(column_name)

if column.null_count > 0:
combined_column = column.combine_chunks()
bitmap_buffer = combined_column.buffers()[0] # Get the null bitmap buffer

if bitmap_buffer is not None:
bitmap_array = np.frombuffer(bitmap_buffer, dtype=np.uint8)

for i in range(num_rows):
byte = bitmap_array[i // 8]
bit = (byte >> (i % 8)) & 1
combined_nulls[i] &= bit

# Determine row indices that have no nulls in any considered column
cdef cnp.ndarray non_null_indices = numpy.nonzero(combined_nulls)[0]

# Process each column by type
for j, column_name in enumerate(join_columns):
column = relation.column(column_name)

# Handle different PyArrow types
if pyarrow.types.is_string(column.type): # String column
for i in non_null_indices:
cell_hashes[j * num_rows + i] = hash(column[i].as_buffer().to_pybytes()) # Hash string
elif pyarrow.types.is_integer(column.type) or pyarrow.types.is_floating(column.type):
# Access the data buffer directly as a NumPy array
np_column = numpy.frombuffer(column.combine_chunks().buffers()[1], dtype=np.int64)
for i in non_null_indices:
cell_hashes[j * num_rows + i] = np_column[i] # Directly store as int64
elif pyarrow.types.is_boolean(column.type):
bitmap_buffer = column.buffers()[1] # Boolean values are stored in bitmap
bitmap_ptr = <uint8_t*>bitmap_buffer.address # Access the bitmap buffer
for i in non_null_indices:
byte_idx = i // 8
bit_idx = i % 8
bit_value = (bitmap_ptr[byte_idx] >> bit_idx) & 1
cell_hashes[j * num_rows + i] = bit_value # Convert to int64 (True -> 1, False -> 0)
elif pyarrow.types.is_date(column.type) or pyarrow.types.is_timestamp(column.type):
np_column = numpy.frombuffer(column.combine_chunks().buffers()[1], dtype=np.int64)
for i in non_null_indices:
cell_hashes[j * num_rows + i] = np_column[i] # Store as int64 timestamp

# Combine hash values (n * 31 + y pattern)
if num_columns == 1:
for i in non_null_indices:
ht.insert(cell_hashes[i], i)
else:
for i in non_null_indices:
combined_hash = 0
for j in range(num_columns):
combined_hash = combined_hash * 31 + cell_hashes[j * num_rows + i]
ht.insert(combined_hash, i) # Insert combined hash into the hash table

free(cell_hashes)
return ht
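
Both builders combine per-column hashes for multi-column join keys with the same polynomial pattern, hash = hash * 31 + hash(value). A minimal pure-Python sketch of that combining step, for illustration only (the function name is ours, and the compiled code wraps to 64-bit int64 rather than Python's unbounded integers):

import math  # not needed; stdlib-only sketch

def combine_row_hash(row_values, multiplier=31):
    # Fold the per-value hashes into a single key: h = h * multiplier + hash(v),
    # masked to mimic the 64-bit wrap-around of the compiled int64 arithmetic.
    h = 0
    for value in row_values:
        h = (h * multiplier + hash(value)) & 0xFFFFFFFFFFFFFFFF
    return h

# Rows that agree on every join column produce the same key, which is what
# the join hash table relies on to find candidate matches.
assert combine_row_hash(("alice", 1)) == combine_row_hash(("alice", 1))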
2 changes: 2 additions & 0 deletions opteryx/operators/cross_join_node.py
@@ -104,12 +104,14 @@ def _cross_join_unnest_column(
                column_data.to_numpy(False)
            )
        else:
+           statistics.optimization_push_filters_into_cross_join_unnest = 1
            indices, new_column_data = build_filtered_rows_indices_and_column(
                column_data.to_numpy(False), conditions
            )

        if single_column and distinct and indices.size > 0:
            # if the unnest target is the only field in the SELECT and we're DISTINCTING
+           statistics.optimization_push_distinct_into_cross_join_unnest = 1
            new_column_data, indices, hash_set = list_distinct(
                new_column_data, indices, hash_set
            )
@@ -54,8 +54,11 @@ def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerContext:
            and context.collected_distincts
            and node.type == "cross join"
            and node.unnest_target is not None
-           and node.pre_update_columns == node.unnest_target.identity
+           and node.pre_update_columns == {node.unnest_target.identity}
        ):
+           # Very specifically testing for a DISTINCT on the unnested column, only.
+           # In this situation we do the DISTINCT on the intermediate results of the CJU,
+           # this means we create smaller tables out of the CROSS JOIN => faster
            node.distinct = True
            context.optimized_plan[context.node_id] = node
            for distict_node in context.collected_distincts:
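
The shape of query this strategy targets is one where the unnested column is the only projected column and it is DISTINCTed. A hypothetical example (dataset and column names are illustrative, not taken from the commit):

import opteryx

# Hypothetical query shape: DISTINCT over only the unnested column, so the
# optimizer can ask the CROSS JOIN UNNEST operator to deduplicate its
# intermediate results rather than materialising the full cross join first.
sql = """
SELECT DISTINCT mission
  FROM $astronauts
 CROSS JOIN UNNEST (missions) AS mission
"""
result = opteryx.query(sql)
print(result.shape)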
4 changes: 2 additions & 2 deletions opteryx/planner/temporary_physical_planner.py
@@ -56,8 +56,8 @@ def create_physical_plan(logical_plan, query_properties) -> ExecutionTree:
    elif node_type == LogicalPlanStepType.Join:
        if node_config.get("type") == "inner":
            # We use our own implementation of INNER JOIN
-           # We have optimized INTEGER and VARCHAR versions
-           if len(node_config["left_columns"]) == 1 and node_config["columns"][0].schema_column.type in {OrsoTypes.INTEGER, OrsoTypes.VARCHAR}:
+           # We have optimized VARCHAR version
+           if len(node_config["left_columns"]) == 1 and node_config["columns"][0].schema_column.type == OrsoTypes.VARCHAR:
                node = operators.InnerJoinSingleNode(query_properties, **node_config)
            else:
                node = operators.InnerJoinNode(query_properties, **node_config)
21 changes: 15 additions & 6 deletions tests/fuzzing/test_sql_fuzzer_join.py
@@ -53,13 +53,23 @@ def generate_condition(columns):
def generate_random_sql_join(columns1, table1, columns2, table2) -> str:
    join_type = random.choice(["INNER JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL OUTER JOIN"])

-   left_column = columns1[random.choice(range(len(columns1)))]
-   right_column = columns2[random.choice(range(len(columns2)))]
-   while left_column.type != right_column.type:
-       left_column = columns1[random.choice(range(len(columns1)))]
-       right_column = columns2[random.choice(range(len(columns2)))]
-
-   join_condition = f"{table1}.{left_column.name} = {table2}.{right_column.name}"
+   last_value = -1
+   this_value = random.random()
+   conditions = []
+   while this_value > last_value:
+       last_value = this_value
+       this_value = random.random()
+
+       left_column = columns1[random.choice(range(len(columns1)))]
+       right_column = columns2[random.choice(range(len(columns2)))]
+       while left_column.type != right_column.type:
+           left_column = columns1[random.choice(range(len(columns1)))]
+           right_column = columns2[random.choice(range(len(columns2)))]
+
+       condition = f"{table1}.{left_column.name} = {table2}.{right_column.name}"
+       conditions.append(condition)
+
+   join_condition = " AND ".join(conditions)
    selected_columns = [f"{table1}.{col.name}" for col in columns1 if random.random() < 0.2] + [f"{table2}.{col.name}" for col in columns2 if random.random() < 0.2]
    if len(selected_columns) == 0:
        selected_columns = ["*"]
@@ -109,7 +119,6 @@ def generate_random_sql_join(columns1, table1, columns2, table2) -> str:
def test_sql_fuzzing_join(i):
    seed = random_int()
    random.seed(seed)
-   print(f"Seed: {seed}")

    table1 = TABLES[random.choice(range(len(TABLES)))]
    table2 = TABLES[random.choice(range(len(TABLES)))]
@@ -124,7 +133,7 @@ def generate_random_sql_join(columns1, table1, columns2, table2) -> str:
    try:
        res = opteryx.query(statement)
        execution_time = time.time() - start_time  # Measure execution time
-       print(f"Shape: {res.shape}, Execution Time: {execution_time:.2f} seconds")
+       print(f"Shape: {res.shape}, Execution Time: {execution_time:.2f} seconds, Seed: {seed}")
        # Additional success criteria checks can be added here
    except Exception as e:
        import traceback
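
The new condition loop keeps adding join predicates while each fresh random draw beats the previous one, so a single condition is the most common outcome and longer AND-chains get progressively rarer. A standalone sketch of that sampling behaviour (the function name is ours, not part of the test):

import random

def sample_condition_count() -> int:
    # Mirrors the fuzzer's loop: the first draw always beats the sentinel -1,
    # so at least one condition is generated; every extra condition requires
    # the run of random draws to keep strictly increasing.
    count, last_value, this_value = 0, -1.0, random.random()
    while this_value > last_value:
        count += 1
        last_value, this_value = this_value, random.random()
    return count

# The expected number of conditions works out to e - 1, roughly 1.72.
counts = [sample_condition_count() for _ in range(10_000)]
print(sum(counts) / len(counts))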
3 changes: 1 addition & 2 deletions tests/query_execution/test_join_flaw.py
@@ -3,7 +3,6 @@
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))
-import pytest

import opteryx

@@ -13,7 +12,7 @@
def test_join_flaw():
    """
    There was a flaw with the join algo that meant that nulls weren't handled correctly, it wasn't
-   consistent (about 1 in 5) so we hammer this query to help determine haveb't regressed this bug
+   consistent (about 1 in 5) so we hammer this query to help determine haven't regressed this bug.
    """
    for i in range(100):
        res = opteryx.query(SQL)
