Cache result of _Reduce constructor
shwina committed Dec 2, 2024
1 parent b35f441 commit ee7fcc9
Showing 1 changed file with 17 additions and 2 deletions.
python/cuda_parallel/cuda/parallel/experimental/__init__.py (17 additions, 2 deletions)
@@ -262,10 +262,25 @@ def __del__(self):
        bindings.cccl_device_reduce_cleanup(ctypes.byref(self.build_result))


def _get_reducer():
    # note: currently this is an unbounded cache:
    cache = {}

    def inner(d_in, d_out, op, init):
        # key on the input/output dtypes, the op callable, and the init dtype,
        # so each distinct combination builds its own _Reduce object
        key = (d_in.dtype, d_out.dtype, op, init.dtype)
        if key in cache:
            return cache[key]
        else:
            result = _Reduce(d_in, d_out, op, init)
            cache[key] = result
            return result

    return inner

get_reducer = _get_reducer()

# TODO Figure out iterators
# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
def reduce_into(d_in, d_out, op, h_init):
"""Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``.
Example:
@@ -295,4 +310,4 @@ def reduce_into(d_in, d_out, op, init):
    Returns:
        A callable object that can be used to perform the reduction
    """
    return get_reducer(d_in, d_out, op, h_init)
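
For context, the new _get_reducer closure is a memoized factory: the _Reduce constructor runs only on a cache miss, and later calls with the same key return the already-built object. Below is a minimal, self-contained sketch of that pattern under illustrative names (ExpensiveBuild and make_cached_factory are stand-ins for this example, not part of cuda.parallel):

class ExpensiveBuild:
    # stand-in for _Reduce: pretend construction is costly (e.g. JIT compilation)
    def __init__(self, *args):
        self.args = args


def make_cached_factory():
    # unbounded cache keyed by whatever the caller supplies, as in the commit
    cache = {}

    def factory(key, *args):
        if key not in cache:
            cache[key] = ExpensiveBuild(*args)  # construct only on a cache miss
        return cache[key]

    return factory


get_instance = make_cached_factory()
a = get_instance(("int32", "int32"), "op")
b = get_instance(("int32", "int32"), "op")
assert a is b  # the second call reused the cached object

With this change, repeated reduce_into calls that use the same input/output dtypes, the same op, and the same h_init dtype should reuse one cached _Reduce instance rather than rebuilding it on every call.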
