Refactor block load test
bernhardmgruber committed Dec 2, 2024
1 parent 9f8e379 commit c3cc8c1
Showing 1 changed file with 34 additions and 83 deletions.
117 changes: 34 additions & 83 deletions cub/test/catch2_test_block_load.cu
@@ -33,37 +33,28 @@
 #include <c2h/catch2_test_helper.h>
 
 template <int ItemsPerThread, int ThreadsInBlock, cub::BlockLoadAlgorithm LoadAlgorithm>
-struct output_idx
+static __device__ int get_output_idx(int item)
 {
-  static __device__ int get(int item)
-  {
-    return static_cast<int>(threadIdx.x) * ItemsPerThread + item;
-  }
-};
-
-template <int ItemsPerThread, int ThreadsInBlock>
-struct output_idx<ItemsPerThread, ThreadsInBlock, cub::BlockLoadAlgorithm::BLOCK_LOAD_STRIPED>
-{
-  static __device__ int get(int item)
+  if (LoadAlgorithm == cub::BlockLoadAlgorithm::BLOCK_LOAD_STRIPED)
   {
     return static_cast<int>(threadIdx.x) + ThreadsInBlock * item;
   }
-};
+  return static_cast<int>(threadIdx.x) * ItemsPerThread + item;
+}
 
-template <typename InputIteratorT,
-          typename OutputIteratorT,
-          int ItemsPerThread,
+template <int ItemsPerThread,
           int ThreadsInBlock,
-          cub::BlockLoadAlgorithm LoadAlgorithm>
-__global__ void kernel(std::integral_constant<bool, true>, InputIteratorT input, OutputIteratorT output, int num_items)
+          cub::BlockLoadAlgorithm LoadAlgorithm,
+          typename InputIteratorT,
+          typename OutputIteratorT>
+__global__ void kernel(cuda::std::true_type, InputIteratorT input, OutputIteratorT output, int num_items)
 {
   using input_t = cub::detail::value_t<InputIteratorT>;
   using block_load_t = cub::BlockLoad<input_t, ThreadsInBlock, ItemsPerThread, LoadAlgorithm>;
   using storage_t = typename block_load_t::TempStorage;
 
   __shared__ storage_t storage;
   block_load_t block_load(storage);
 
   input_t data[ItemsPerThread];
 
   if (ItemsPerThread * ThreadsInBlock == num_items)
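The hunk above folds the output_idx class template and its BLOCK_LOAD_STRIPED specialization into a single get_output_idx function with a plain branch on the LoadAlgorithm template parameter. For intuition about the two layouts it selects between, here is a minimal host-only sketch; the sizes are illustrative assumptions, not part of the commit:

  #include <cstdio>

  int main()
  {
    constexpr int threads_in_block = 4; // assumed toy block size
    constexpr int items_per_thread = 2; // assumed toy items per thread
    for (int t = 0; t < threads_in_block; ++t)
    {
      for (int i = 0; i < items_per_thread; ++i)
      {
        const int blocked = t * items_per_thread + i; // default layout
        const int striped = t + threads_in_block * i; // BLOCK_LOAD_STRIPED layout
        std::printf("thread %d, item %d -> blocked %d, striped %d\n", t, i, blocked, striped);
      }
    }
  }

Thread 1, for instance, owns slots {2, 3} under the blocked layout but {1, 5} under the striped one. Since LoadAlgorithm is a template parameter, the branch is a compile-time constant and the generated device code contains only one of the two returns.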
@@ -77,50 +68,45 @@ __global__ void kernel(std::integral_constant<bool, true>, InputIteratorT input,
 
   for (int i = 0; i < ItemsPerThread; i++)
   {
-    const int idx = output_idx<ItemsPerThread, ThreadsInBlock, LoadAlgorithm>::get(i);
-
+    const int idx = get_output_idx<ItemsPerThread, ThreadsInBlock, LoadAlgorithm>(i);
     if (idx < num_items)
     {
       output[idx] = data[i];
     }
   }
 }
 
-template <typename InputIteratorT,
-          typename OutputIteratorT,
-          int ItemsPerThread,
+template <int ItemsPerThread,
           int ThreadsInBlock,
-          cub::BlockLoadAlgorithm /* LoadAlgorithm */>
-__global__ void kernel(std::integral_constant<bool, false>, InputIteratorT input, OutputIteratorT output, int num_items)
+          cub::BlockLoadAlgorithm /* LoadAlgorithm */,
+          typename InputIteratorT,
+          typename OutputIteratorT>
+__global__ void kernel(cuda::std::false_type, InputIteratorT input, OutputIteratorT output, int num_items)
 {
   for (int i = 0; i < ItemsPerThread; i++)
   {
-    const int idx = output_idx<ItemsPerThread, ThreadsInBlock, cub::BlockLoadAlgorithm::BLOCK_LOAD_DIRECT>::get(i);
-
+    const int idx = get_output_idx<ItemsPerThread, ThreadsInBlock, cub::BlockLoadAlgorithm::BLOCK_LOAD_DIRECT>(i);
     if (idx < num_items)
     {
       output[idx] = input[idx];
     }
   }
 }
 
-template <int ItemsPerThread,
-          int ThreadsInBlock,
-          cub::BlockLoadAlgorithm LoadAlgorithm,
-          typename InputIteratorT,
-          typename OutputIteratorT>
-void block_load(InputIteratorT input, OutputIteratorT output, int num_items)
+template <int ItemsPerThread, int ThreadsInBlock, cub::BlockLoadAlgorithm LoadAlgorithm, typename T, typename InputIteratorT>
+void test_block_load(const c2h::device_vector<T>& d_input, InputIteratorT input)
 {
-  using input_t = cub::detail::value_t<InputIteratorT>;
-  using block_load_t = cub::BlockLoad<input_t, ThreadsInBlock, ItemsPerThread, LoadAlgorithm>;
+  using block_load_t = cub::BlockLoad<T, ThreadsInBlock, ItemsPerThread, LoadAlgorithm>;
   using storage_t = typename block_load_t::TempStorage;
-  constexpr bool sufficient_resources = sizeof(storage_t) <= cub::detail::max_smem_per_block;
-
-  kernel<InputIteratorT, OutputIteratorT, ItemsPerThread, ThreadsInBlock, LoadAlgorithm>
-    <<<1, ThreadsInBlock>>>(std::integral_constant<bool, sufficient_resources>{}, input, output, num_items);
+  constexpr auto sufficient_resources =
+    cuda::std::bool_constant<sizeof(storage_t) <= cub::detail::max_smem_per_block>{};
+
+  c2h::device_vector<T> d_output(d_input.size());
+  kernel<ItemsPerThread, ThreadsInBlock, LoadAlgorithm>
+    <<<1, ThreadsInBlock>>>(sufficient_resources, input, thrust::raw_pointer_cast(d_output.data()), d_input.size());
   REQUIRE(cudaSuccess == cudaPeekAtLastError());
   REQUIRE(cudaSuccess == cudaDeviceSynchronize());
+  REQUIRE(d_input == d_output);
 }
 
 // %PARAM% IPT it 1:11
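The two kernel overloads above implement tag dispatch: test_block_load evaluates at compile time whether BlockLoad's TempStorage fits into shared memory and passes the result as a cuda::std::bool_constant, so the true_type overload (which allocates the __shared__ storage) is only instantiated with the heavy machinery when the resources suffice, while the false_type overload falls back to a plain element-wise copy. A stripped-down sketch of the same idiom; the names and the 48 KiB budget are assumptions for illustration only:

  #include <cuda/std/type_traits>

  // Overload selected when the shared-memory footprint fits.
  template <int SmemBytes>
  __global__ void payload(cuda::std::true_type /* fits */)
  {
    __shared__ char scratch[SmemBytes]; // the resource being guarded
    scratch[threadIdx.x % SmemBytes] = 0;
  }

  // Fallback overload: never touches shared memory.
  template <int SmemBytes>
  __global__ void payload(cuda::std::false_type /* does not fit */)
  {}

  template <int SmemBytes>
  void launch()
  {
    constexpr auto fits = cuda::std::bool_constant<(SmemBytes <= 48 * 1024)>{}; // assumed budget
    payload<SmemBytes><<<1, 32>>>(fits); // overload resolution picks the viable kernel
  }

Because the false_type overload of kernel still routes input[idx] straight to output[idx], the round-trip REQUIRE in test_block_load passes even for configurations whose TempStorage would exceed cub::detail::max_smem_per_block.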
@@ -173,14 +159,8 @@ C2H_TEST("Block load works with even block sizes",
   c2h::device_vector<type> d_input(GENERATE_COPY(take(10, random(0, params::tile_size))));
   c2h::gen(C2H_SEED(10), d_input);
 
-  c2h::device_vector<type> d_output(d_input.size());
-
-  block_load<params::items_per_thread, params::threads_in_block, params::load_algorithm>(
-    thrust::raw_pointer_cast(d_input.data()),
-    thrust::raw_pointer_cast(d_output.data()),
-    static_cast<int>(d_input.size()));
-
-  REQUIRE(d_input == d_output);
+  test_block_load<params::items_per_thread, params::threads_in_block, params::load_algorithm>(
+    d_input, thrust::raw_pointer_cast(d_input.data()));
 }
 
 C2H_TEST("Block load works with even odd sizes",
@@ -195,15 +175,8 @@ C2H_TEST("Block load works with even odd sizes",
 
   c2h::device_vector<type> d_input(GENERATE_COPY(take(10, random(0, params::tile_size))));
   c2h::gen(C2H_SEED(10), d_input);
-
-  c2h::device_vector<type> d_output(d_input.size());
-
-  block_load<params::items_per_thread, params::threads_in_block, params::load_algorithm>(
-    thrust::raw_pointer_cast(d_input.data()),
-    thrust::raw_pointer_cast(d_output.data()),
-    static_cast<int>(d_input.size()));
-
-  REQUIRE(d_input == d_output);
+  test_block_load<params::items_per_thread, params::threads_in_block, params::load_algorithm>(
+    d_input, thrust::raw_pointer_cast(d_input.data()));
 }
 
 C2H_TEST(
@@ -214,15 +187,8 @@ C2H_TEST(
 
   c2h::device_vector<type> d_input(GENERATE_COPY(take(10, random(0, params::tile_size))));
   c2h::gen(C2H_SEED(10), d_input);
-
-  c2h::device_vector<type> d_output(d_input.size());
-
-  block_load<params::items_per_thread, params::threads_in_block, params::load_algorithm>(
-    thrust::raw_pointer_cast(d_input.data()),
-    thrust::raw_pointer_cast(d_output.data()),
-    static_cast<int>(d_input.size()));
-
-  REQUIRE(d_input == d_output);
+  test_block_load<params::items_per_thread, params::threads_in_block, params::load_algorithm>(
+    d_input, thrust::raw_pointer_cast(d_input.data()));
 }
 
 C2H_TEST("Block load works with custom types", "[load][block]", items_per_thread, load_algorithm)
@@ -235,15 +201,7 @@ C2H_TEST("Block load works with custom types", "[load][block]", items_per_thread
 
   c2h::device_vector<type> d_input(GENERATE_COPY(take(10, random(0, tile_size))));
   c2h::gen(C2H_SEED(10), d_input);
-
-  c2h::device_vector<type> d_output(d_input.size());
-
-  block_load<items_per_thread, threads_in_block, load_algorithm>(
-    thrust::raw_pointer_cast(d_input.data()),
-    thrust::raw_pointer_cast(d_output.data()),
-    static_cast<int>(d_input.size()));
-
-  REQUIRE(d_input == d_output);
+  test_block_load<items_per_thread, threads_in_block, load_algorithm>(d_input, thrust::raw_pointer_cast(d_input.data()));
 }
 
 C2H_TEST("Block load works with caching iterators", "[load][block]", items_per_thread, load_algorithm)
@@ -256,14 +214,7 @@ C2H_TEST("Block load works with caching iterators", "[load][block]", items_per_t
 
   c2h::device_vector<type> d_input(GENERATE_COPY(take(10, random(0, tile_size))));
   c2h::gen(C2H_SEED(10), d_input);
 
   cub::CacheModifiedInputIterator<cub::CacheLoadModifier::LOAD_DEFAULT, type> in(
     thrust::raw_pointer_cast(d_input.data()));
-
-  c2h::device_vector<type> d_output(d_input.size());
-
-  block_load<items_per_thread, threads_in_block, load_algorithm>(
-    in, thrust::raw_pointer_cast(d_output.data()), static_cast<int>(d_input.size()));
-
-  REQUIRE(d_input == d_output);
+  test_block_load<items_per_thread, threads_in_block, load_algorithm>(d_input, in);
 }
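The final test exercises BlockLoad through cub::CacheModifiedInputIterator, a value-type fancy iterator that wraps a raw device pointer and routes every dereference through the given cache modifier. A small self-contained sketch of the wrapper on its own; the kernel and the LOAD_LDG modifier are illustrative choices, not taken from the test:

  #include <cub/iterator/cache_modified_input_iterator.cuh>

  // Sums n values; each read is issued with the LDG cache modifier.
  __global__ void sum_kernel(cub::CacheModifiedInputIterator<cub::CacheLoadModifier::LOAD_LDG, int> in,
                             int* out, int n)
  {
    if (threadIdx.x == 0)
    {
      int acc = 0;
      for (int i = 0; i < n; ++i)
      {
        acc += in[i]; // dereference applies the modifier
      }
      *out = acc;
    }
  }

Since the iterator is trivially copyable, the test can pass it to test_block_load and on into the kernel exactly like the raw pointer it wraps, while d_input stays available for the final equality check.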
