Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PERF] cuda.parallel: Cache intermediate results to improve performance of cudax.reduce_into #3001

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions python/cuda_parallel/cuda/parallel/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import functools
import importlib
import ctypes
import shutil
Expand Down Expand Up @@ -89,6 +90,9 @@ class _CCCLValue(ctypes.Structure):
("state", ctypes.c_void_p)]


# TODO: replace with functools.cache once our docs build environment
# is upgraded to at least Python 3.9
@functools.lru_cache(maxsize=None)
def _type_to_info(numpy_type):
numba_type = numba.from_dtype(numpy_type)
context = cuda.descriptor.cuda_target.target_context
Expand Down Expand Up @@ -194,14 +198,15 @@ def _dtype_validation(dt1, dt2):


class _Reduce:
def __init__(self, d_in, d_out, op, init):
# TODO: constructor shouldn't require concrete `d_in`, `d_out`:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what I was wondering about, but I didn't get to drilling down.

It might be useful to work together on getting this TODO done.

This PR will already have complicated merge conflicts with my #2788, so it might be best to team up and work on both.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be useful to work together on getting this TODO done.

Definitely! I opened #3008 to track this. Let's tackle it as a follow-up to this PR.

def __init__(self, d_in, d_out, op, h_init):
self._ctor_d_in_dtype = d_in.dtype
self._ctor_d_out_dtype = d_out.dtype
self._ctor_init_dtype = init.dtype
self._ctor_init_dtype = h_init.dtype
cc_major, cc_minor = cuda.get_current_device().compute_capability
cub_path, thrust_path, libcudacxx_path, cuda_include_path = _get_paths()
bindings = _get_bindings()
accum_t = init.dtype
accum_t = h_init.dtype
self.op_wrapper = _Op(accum_t, op)
d_in_ptr = _device_array_to_pointer(d_in)
d_out_ptr = _device_array_to_pointer(d_out)
Expand All @@ -212,7 +217,7 @@ def __init__(self, d_in, d_out, op, init):
d_in_ptr,
d_out_ptr,
self.op_wrapper.handle(),
_host_array_to_value(init),
_host_array_to_value(h_init),
cc_major,
cc_minor,
ctypes.c_char_p(cub_path),
Expand All @@ -223,11 +228,11 @@ def __init__(self, d_in, d_out, op, init):
if error != enums.CUDA_SUCCESS:
raise ValueError('Error building reduce')

def __call__(self, temp_storage, d_in, d_out, init):
def __call__(self, temp_storage, d_in, d_out, h_init):
# TODO validate POINTER vs ITERATOR when iterator support is added
_dtype_validation(self._ctor_d_in_dtype, d_in.dtype)
_dtype_validation(self._ctor_d_out_dtype, d_out.dtype)
_dtype_validation(self._ctor_init_dtype, init.dtype)
_dtype_validation(self._ctor_init_dtype, h_init.dtype)
bindings = _get_bindings()
if temp_storage is None:
temp_storage_bytes = ctypes.c_size_t()
Expand All @@ -247,7 +252,7 @@ def __call__(self, temp_storage, d_in, d_out, init):
d_out_ptr,
num_items,
self.op_wrapper.handle(),
_host_array_to_value(init),
_host_array_to_value(h_init),
None)
if error != enums.CUDA_SUCCESS:
raise ValueError('Error reducing')
Expand All @@ -259,10 +264,25 @@ def __del__(self):
bindings.cccl_device_reduce_cleanup(ctypes.byref(self.build_result))


def _get_reducer():
    """Create the memoizing factory for ``_Reduce`` objects.

    Returns a function that builds a ``_Reduce`` the first time it sees a
    given combination of input dtype, output dtype, operator and
    initial-value dtype, and hands back that same object on every later
    call made with an equal combination.
    """
    # NOTE: nothing is ever evicted, so this mapping grows without bound.
    built_reducers = {}

    def inner(d_in, d_out, op, init):
        # Only the dtypes and the operator take part in the lookup; the
        # concrete arguments are used solely to construct a missing entry.
        signature = (d_in.dtype, d_out.dtype, op, init.dtype)
        reducer = built_reducers.get(signature)
        if reducer is None:
            reducer = _Reduce(d_in, d_out, op, init)
            built_reducers[signature] = reducer
        return reducer

    return inner


get_reducer = _get_reducer()

# TODO Figure out iterators
# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
def reduce_into(d_in, d_out, op, init):
def reduce_into(d_in, d_out, op, h_init):
"""Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``.

Example:
Expand Down Expand Up @@ -292,4 +312,4 @@ def reduce_into(d_in, d_out, op, init):
Returns:
A callable object that can be used to perform the reduction
"""
return _Reduce(d_in, d_out, op, init)
return get_reducer(d_in, d_out, op, h_init)
Loading