diff --git a/cub/test/catch2_test_block_load.cu b/cub/test/catch2_test_block_load.cu index 6ceb6bc43af..0aa105e3889 100644 --- a/cub/test/catch2_test_block_load.cu +++ b/cub/test/catch2_test_block_load.cu @@ -33,29 +33,21 @@ #include template -struct output_idx +static __device__ int get_output_idx(int item) { - static __device__ int get(int item) - { - return static_cast(threadIdx.x) * ItemsPerThread + item; - } -}; - -template -struct output_idx -{ - static __device__ int get(int item) + if (LoadAlgorithm == cub::BlockLoadAlgorithm::BLOCK_LOAD_STRIPED) { return static_cast(threadIdx.x) + ThreadsInBlock * item; } -}; + return static_cast(threadIdx.x) * ItemsPerThread + item; +} -template -__global__ void kernel(std::integral_constant, InputIteratorT input, OutputIteratorT output, int num_items) + cub::BlockLoadAlgorithm LoadAlgorithm, + typename InputIteratorT, + typename OutputIteratorT> +__global__ void kernel(cuda::std::true_type, InputIteratorT input, OutputIteratorT output, int num_items) { using input_t = cub::detail::value_t; using block_load_t = cub::BlockLoad; @@ -63,7 +55,6 @@ __global__ void kernel(std::integral_constant, InputIteratorT input, __shared__ storage_t storage; block_load_t block_load(storage); - input_t data[ItemsPerThread]; if (ItemsPerThread * ThreadsInBlock == num_items) @@ -77,8 +68,7 @@ __global__ void kernel(std::integral_constant, InputIteratorT input, for (int i = 0; i < ItemsPerThread; i++) { - const int idx = output_idx::get(i); - + const int idx = get_output_idx(i); if (idx < num_items) { output[idx] = data[i]; @@ -86,17 +76,16 @@ __global__ void kernel(std::integral_constant, InputIteratorT input, } } -template -__global__ void kernel(std::integral_constant, InputIteratorT input, OutputIteratorT output, int num_items) + cub::BlockLoadAlgorithm /* LoadAlgorithm */, + typename InputIteratorT, + typename OutputIteratorT> +__global__ void kernel(cuda::std::false_type, InputIteratorT input, OutputIteratorT output, int num_items) { for (int i = 0; i < ItemsPerThread; i++) { - const int idx = output_idx::get(i); - + const int idx = get_output_idx(i); if (idx < num_items) { output[idx] = input[idx]; @@ -104,23 +93,20 @@ __global__ void kernel(std::integral_constant, InputIteratorT input } } -template -void block_load(InputIteratorT input, OutputIteratorT output, int num_items) +template +void test_block_load(const c2h::device_vector& d_input, InputIteratorT input) { - using input_t = cub::detail::value_t; - using block_load_t = cub::BlockLoad; - using storage_t = typename block_load_t::TempStorage; - constexpr bool sufficient_resources = sizeof(storage_t) <= cub::detail::max_smem_per_block; - - kernel - <<<1, ThreadsInBlock>>>(std::integral_constant{}, input, output, num_items); + using block_load_t = cub::BlockLoad; + using storage_t = typename block_load_t::TempStorage; + constexpr auto sufficient_resources = + cuda::std::bool_constant{}; + c2h::device_vector d_output(d_input.size()); + kernel + <<<1, ThreadsInBlock>>>(sufficient_resources, input, thrust::raw_pointer_cast(d_output.data()), d_input.size()); REQUIRE(cudaSuccess == cudaPeekAtLastError()); REQUIRE(cudaSuccess == cudaDeviceSynchronize()); + REQUIRE(d_input == d_output); } // %PARAM% IPT it 1:11 @@ -173,14 +159,8 @@ C2H_TEST("Block load works with even block sizes", c2h::device_vector d_input(GENERATE_COPY(take(10, random(0, params::tile_size)))); c2h::gen(C2H_SEED(10), d_input); - c2h::device_vector d_output(d_input.size()); - - block_load( - thrust::raw_pointer_cast(d_input.data()), - thrust::raw_pointer_cast(d_output.data()), - static_cast(d_input.size())); - - REQUIRE(d_input == d_output); + test_block_load( + d_input, thrust::raw_pointer_cast(d_input.data())); } C2H_TEST("Block load works with even odd sizes", @@ -195,15 +175,8 @@ C2H_TEST("Block load works with even odd sizes", c2h::device_vector d_input(GENERATE_COPY(take(10, random(0, params::tile_size)))); c2h::gen(C2H_SEED(10), d_input); - - c2h::device_vector d_output(d_input.size()); - - block_load( - thrust::raw_pointer_cast(d_input.data()), - thrust::raw_pointer_cast(d_output.data()), - static_cast(d_input.size())); - - REQUIRE(d_input == d_output); + test_block_load( + d_input, thrust::raw_pointer_cast(d_input.data())); } C2H_TEST( @@ -214,15 +187,8 @@ C2H_TEST( c2h::device_vector d_input(GENERATE_COPY(take(10, random(0, params::tile_size)))); c2h::gen(C2H_SEED(10), d_input); - - c2h::device_vector d_output(d_input.size()); - - block_load( - thrust::raw_pointer_cast(d_input.data()), - thrust::raw_pointer_cast(d_output.data()), - static_cast(d_input.size())); - - REQUIRE(d_input == d_output); + test_block_load( + d_input, thrust::raw_pointer_cast(d_input.data())); } C2H_TEST("Block load works with custom types", "[load][block]", items_per_thread, load_algorithm) @@ -235,15 +201,7 @@ C2H_TEST("Block load works with custom types", "[load][block]", items_per_thread c2h::device_vector d_input(GENERATE_COPY(take(10, random(0, tile_size)))); c2h::gen(C2H_SEED(10), d_input); - - c2h::device_vector d_output(d_input.size()); - - block_load( - thrust::raw_pointer_cast(d_input.data()), - thrust::raw_pointer_cast(d_output.data()), - static_cast(d_input.size())); - - REQUIRE(d_input == d_output); + test_block_load(d_input, thrust::raw_pointer_cast(d_input.data())); } C2H_TEST("Block load works with caching iterators", "[load][block]", items_per_thread, load_algorithm) @@ -256,14 +214,7 @@ C2H_TEST("Block load works with caching iterators", "[load][block]", items_per_t c2h::device_vector d_input(GENERATE_COPY(take(10, random(0, tile_size)))); c2h::gen(C2H_SEED(10), d_input); - cub::CacheModifiedInputIterator in( thrust::raw_pointer_cast(d_input.data())); - - c2h::device_vector d_output(d_input.size()); - - block_load( - in, thrust::raw_pointer_cast(d_output.data()), static_cast(d_input.size())); - - REQUIRE(d_input == d_output); + test_block_load(d_input, in); }