
Commit

test
shwina committed Dec 2, 2024
1 parent af0a8bb commit 46e75e9
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions python/cuda_parallel/cuda/parallel/experimental/__init__.py
@@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

+import functools
import importlib
import ctypes
import shutil
@@ -89,6 +90,7 @@ class _CCCLValue(ctypes.Structure):
("state", ctypes.c_void_p)]


+@functools.cache
def _type_to_info(numpy_type):
numba_type = numba.from_dtype(numpy_type)
context = cuda.descriptor.cuda_target.target_context
@@ -194,14 +196,15 @@ def _dtype_validation(dt1, dt2):


class _Reduce:
-def __init__(self, d_in, d_out, op, init):
+# TODO: constructor shouldn't require concrete `d_in`, `d_out`:
+def __init__(self, d_in, d_out, op, h_init):
self._ctor_d_in_dtype = d_in.dtype
self._ctor_d_out_dtype = d_out.dtype
-self._ctor_init_dtype = init.dtype
+self._ctor_init_dtype = h_init.dtype
cc_major, cc_minor = cuda.get_current_device().compute_capability
cub_path, thrust_path, libcudacxx_path, cuda_include_path = _get_paths()
bindings = _get_bindings()
-accum_t = init.dtype
+accum_t = h_init.dtype
self.op_wrapper = _Op(accum_t, op)
d_in_ptr = _device_array_to_pointer(d_in)
d_out_ptr = _device_array_to_pointer(d_out)
@@ -212,7 +215,7 @@ def __init__(self, d_in, d_out, op, init):
d_in_ptr,
d_out_ptr,
self.op_wrapper.handle(),
-_host_array_to_value(init),
+_host_array_to_value(h_init),
cc_major,
cc_minor,
ctypes.c_char_p(cub_path),
@@ -223,11 +226,11 @@ def __init__(self, d_in, d_out, op, init):
if error != enums.CUDA_SUCCESS:
raise ValueError('Error building reduce')

-def __call__(self, temp_storage, d_in, d_out, init):
+def __call__(self, temp_storage, d_in, d_out, h_init):
# TODO validate POINTER vs ITERATOR when iterator support is added
_dtype_validation(self._ctor_d_in_dtype, d_in.dtype)
_dtype_validation(self._ctor_d_out_dtype, d_out.dtype)
-_dtype_validation(self._ctor_init_dtype, init.dtype)
+_dtype_validation(self._ctor_init_dtype, h_init.dtype)
bindings = _get_bindings()
if temp_storage is None:
temp_storage_bytes = ctypes.c_size_t()
@@ -247,7 +250,7 @@ def __call__(self, temp_storage, d_in, d_out, init):
d_out_ptr,
num_items,
self.op_wrapper.handle(),
-_host_array_to_value(init),
+_host_array_to_value(h_init),
None)
if error != enums.CUDA_SUCCESS:
raise ValueError('Error reducing')
@@ -259,10 +262,25 @@ def __del__(self):
bindings.cccl_device_reduce_cleanup(ctypes.byref(self.build_result))


+def _get_reducer():
+    # note: currently this is an unbounded cache:
+    cache = {}
+    def inner(d_in, d_out, op, init):
+        key = (d_in.dtype, d_out.dtype, op, init.dtype)
+        if key in cache:
+            return cache[key]
+        else:
+            result = _Reduce(d_in, d_out, op, init)
+            cache[key] = result
+            return result
+    return inner
+
+get_reducer = _get_reducer()

# TODO Figure out iterators
# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
-def reduce_into(d_in, d_out, op, init):
+def reduce_into(d_in, d_out, op, h_init):
"""Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``.
Example:
@@ -292,4 +310,4 @@ def reduce_into(d_in, d_out, op, init):
Returns:
A callable object that can be used to perform the reduction
"""
-return _Reduce(d_in, d_out, op, init)
+return get_reducer(d_in, d_out, op, h_init)
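For reference, a minimal usage sketch of the API changed above (not part of this commit). It assumes the callable returned by reduce_into reports the required temporary-storage size when called with temp_storage=None, and that device arrays come from numba.cuda; neither detail is fully visible in this diff.

import numpy as np
from numba import cuda
from cuda.parallel.experimental import reduce_into

def add_op(a, b):
    return a + b

h_init = np.array([0], dtype=np.int32)                          # host-side initial value
d_in = cuda.to_device(np.array([1, 2, 3, 4], dtype=np.int32))
d_out = cuda.device_array(1, dtype=np.int32)

reducer = reduce_into(d_in, d_out, add_op, h_init)              # build step (cached per dtypes/op)
temp_bytes = reducer(None, d_in, d_out, h_init)                 # assumed: size query when temp_storage is None
d_temp = cuda.device_array(temp_bytes, dtype=np.uint8)          # allocate scratch storage
reducer(d_temp, d_in, d_out, h_init)                            # run the reduction
print(d_out.copy_to_host())                                     # expected: [10]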

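The _get_reducer cache added in this commit is noted as unbounded. Below is a minimal sketch of one possible bounded variant keyed the same way; the maxsize parameter and the LRU eviction policy are illustrative assumptions, not part of the commit.

from collections import OrderedDict

def _get_reducer_bounded(maxsize=128):
    # Hypothetical bounded counterpart of _get_reducer: same cache key, LRU eviction.
    cache = OrderedDict()

    def inner(d_in, d_out, op, h_init):
        key = (d_in.dtype, d_out.dtype, op, h_init.dtype)
        if key in cache:
            cache.move_to_end(key)          # mark as most recently used
            return cache[key]
        result = _Reduce(d_in, d_out, op, h_init)
        cache[key] = result
        if len(cache) > maxsize:
            cache.popitem(last=False)       # evict the least recently used entry
        return result

    return inner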