diff --git a/src/ocl/tensorocl.cpp b/src/ocl/tensorocl.cpp index 985861c09f..81d735a2bc 100644 --- a/src/ocl/tensorocl.cpp +++ b/src/ocl/tensorocl.cpp @@ -2135,8 +2135,10 @@ void CastTensor(const Handle& handle, MIOPEN_THROW(miopenStatusBadParm, "Tensor dimension sizes unsupported."); } + auto miopen_alpha = *(static_cast(alpha)); + if(srcDesc.GetType() == dstDesc.GetType() && srcOffset == 0 && dstOffset == 0 && - srcDesc_flat.IsPacked() && dstDesc_flat.IsPacked()) + srcDesc_flat.IsPacked() && dstDesc_flat.IsPacked() && float_equal(miopen_alpha, 1.0)) { handle.Copy(src, dst, srcDesc_flat.GetElementSize() * GetTypeSize(srcDesc_flat.GetType())); } @@ -2146,7 +2148,9 @@ void CastTensor(const Handle& handle, const std::vector& lens = srcDesc_flat.GetLengths(); - std::string network_config = "cast " + std::to_string(dstDesc_flat.GetType()); + // TODO: make proper network config + std::string network_config = "cast " + std::to_string(srcDesc_flat.GetType()) + + std::to_string(dstDesc_flat.GetType()); for(auto& len : lens) { network_config += " " + std::to_string(len); @@ -2155,8 +2159,6 @@ void CastTensor(const Handle& handle, auto&& kernels = handle.GetKernels(kernel_name, network_config); KernelInvoke kernel; - auto miopen_alpha = *(static_cast(alpha)); - if(!kernels.empty()) { kernel = kernels.front(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 796078b2b0..a3f2ec1b80 100755 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -265,11 +265,10 @@ option( WORKAROUND_ISSUE_1148 "" ${WORKAROUND_ISSUE_1148_DEFAULT}) if(MIOPEN_TEST_INT8) set(SKIP_ALL_EXCEPT_TESTS - test_tensor_vec test_tensor_cast test_tensor_trans test_tensor_copy test_tensor_set - test_tensor_transform test_conv2d test_conv2d_find2) + test_tensor_vec test_tensor_trans test_tensor_transform test_conv2d test_conv2d_find2) elseif(MIOPEN_TEST_BFLOAT16) set(SKIP_ALL_EXCEPT_TESTS - test_conv2d test_conv2d_find2 test_tensor_copy test_tensor_set test_tensor_vec test_immed_conv2d + test_conv2d test_conv2d_find2 test_tensor_vec test_immed_conv2d test_check_numerics_test test_conv_extra test_conv_for_implicit_gemm test_miopen_conv test_deepbench_conv test_conv_igemm_dynamic_xdlops_nhwc_wrw_bf16 test_conv_igemm_dynamic_xdlops_nhwc_fwd_bf16 diff --git a/test/conv2d_bias.cpp b/test/conv2d_bias.cpp index 62389fe69d..6e219609df 100644 --- a/test/conv2d_bias.cpp +++ b/test/conv2d_bias.cpp @@ -35,7 +35,7 @@ struct conv2d_bias_driver : public conv_bias_driver tensor_elem_gen_checkboard_sign{}(is...); }; - this->add(this->output, "output", this->get_tensor(get_inputs, gen_value)); + this->add(this->output, "output", this->get_tensor(get_inputs, gen_value)); } }; diff --git a/test/conv3d_bias.cpp b/test/conv3d_bias.cpp index f2cfadfd1c..cfd4891484 100644 --- a/test/conv3d_bias.cpp +++ b/test/conv3d_bias.cpp @@ -35,7 +35,8 @@ struct conv3d_bias_driver : public conv_bias_driver tensor_elem_gen_checkboard_sign{}(is...); }; - this->add(this->output, "output", this->get_tensor(get_3d_conv_input_shapes, gen_value)); + this->add( + this->output, "output", this->get_tensor(get_3d_conv_input_shapes, gen_value)); } }; diff --git a/test/gtest/binary_tensor_ops.cpp b/test/gtest/binary_tensor_ops.cpp new file mode 100644 index 0000000000..ae3a06330b --- /dev/null +++ b/test/gtest/binary_tensor_ops.cpp @@ -0,0 +1,295 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include +#include + +#include "get_handle.hpp" +#include "tensor_holder.hpp" +#include "tensor_util.hpp" +#include "verify.hpp" + +namespace { +using BinaryTensorOpsCase = std::tuple, std::vector, float, bool>; + +template +class GPU_binaryTensorOps : public ::testing::TestWithParam +{ +public: + static tensor dstSuperTensor; + static tensor srcSuperTensor; + + static void SetUpTestSuite() + { + static constexpr auto dstType = miopen_type{}; + static constexpr auto srcType = miopen_type{}; + + uint64_t dstMaxValue = dstType == miopenHalf ? 5 : (dstType == miopenInt8 ? 126 : 32767); + uint64_t srcMaxValue = srcType == miopenHalf ? 5 : (srcType == miopenInt8 ? 126 : 32767); + + uint64_t maxValue = std::min(dstMaxValue, srcMaxValue); + + dstSuperTensor = tensor{std::vector{32, 32, 16, 16, 16}}.generate( + tensor_elem_gen_integer{maxValue}); + + srcSuperTensor = tensor{std::vector{32, 16, 32, 16, 16}}.generate( + tensor_elem_gen_integer{maxValue}); + } + +protected: + size_t dstDataSize; + size_t srcDataSize; + miopen::TensorDescriptor dstDesc; + miopen::TensorDescriptor srcDesc; + + void SetUp() override + { + const auto& [lens, offsets, alpha, clamp] = GetParam(); + + ASSERT_GE(dstSuperTensor.desc.GetNumDims(), lens.size()); + + const std::vector& dstSuperStrides = dstSuperTensor.desc.GetStrides(); + std::vector dstStrides(dstSuperStrides.begin() + + (dstSuperTensor.desc.GetNumDims() - lens.size()), + dstSuperStrides.end()); + + dstDesc = miopen::TensorDescriptor(miopen_type{}, lens, dstStrides); + dstDataSize = dstDesc.GetElementSpace() + offsets[1]; + + ASSERT_GE(srcSuperTensor.desc.GetElementSpace(), dstDataSize); + + ASSERT_GE(srcSuperTensor.desc.GetNumDims(), lens.size()); + + const std::vector& srcSuperStrides = srcSuperTensor.desc.GetStrides(); + std::vector srcStrides(srcSuperStrides.begin() + + (srcSuperTensor.desc.GetNumDims() - lens.size()), + srcSuperStrides.end()); + + srcDesc = miopen::TensorDescriptor(miopen_type{}, lens, srcStrides); + srcDataSize = srcDesc.GetElementSpace() + offsets[0]; + + ASSERT_GE(srcSuperTensor.desc.GetElementSpace(), srcDataSize); + } + + void RunCast() + { + std::vector dstSuperCpu(dstSuperTensor.begin(), + dstSuperTensor.begin() + dstDataSize); + std::vector srcSuperCpu(srcSuperTensor.begin(), + srcSuperTensor.begin() + srcDataSize); + + const auto [srcOffset, dstOffset] = miopen::tien<2>(std::get>(GetParam())); + const auto alpha = std::get(GetParam()); + const auto clamp = std::get(GetParam()); + + auto&& handle = get_handle(); + auto dstSuper_dev = handle.Write(dstSuperCpu); + auto srcSuper_dev = handle.Write(srcSuperCpu); + + miopen::CastTensor(handle, + &alpha, + clamp, + srcDesc, + srcSuper_dev.get(), + dstDesc, + dstSuper_dev.get(), + srcOffset, + dstOffset); + + auto result = handle.Read(dstSuper_dev, dstDataSize); + + if(clamp) + { + operate_over_subtensor( + [alpha, clampVal = static_cast(std::numeric_limits::max())]( + auto& dst, auto src) { + dst = std::min(static_cast(src) * alpha, clampVal); + }, + dstSuperCpu, + srcSuperCpu, + dstDesc, + srcDesc, + dstOffset, + srcOffset); + } + else + { + operate_over_subtensor( + [alpha](auto& dst, auto src) { dst = static_cast(src) * alpha; }, + dstSuperCpu, + srcSuperCpu, + dstDesc, + srcDesc, + dstOffset, + srcOffset); + } + + auto mismatch_index = miopen::mismatch_idx(dstSuperCpu, result, miopen::float_equal); + auto mismatch_src_index = mismatch_index - dstOffset + srcOffset; + + ASSERT_EQ(result.size(), mismatch_index) + << "The first mismatched elements are:" // + << " Src[" << mismatch_src_index << "] " << srcSuperCpu[mismatch_src_index] // + << " GPU[" << mismatch_index << "] " << result[mismatch_index] // + << " Ref[" << mismatch_index << "] " << dstSuperCpu[mismatch_index]; // + } + + void RunCopy() + { + std::vector dstSuperCpu(dstSuperTensor.begin(), + dstSuperTensor.begin() + dstDataSize); + std::vector srcSuperCpu(srcSuperTensor.begin(), + srcSuperTensor.begin() + srcDataSize); + + const auto [srcOffset, dstOffset] = miopen::tien<2>(std::get>(GetParam())); + + auto&& handle = get_handle(); + auto dstSuper_dev = handle.Write(dstSuperCpu); + auto srcSuper_dev = handle.Write(srcSuperCpu); + + miopen::CopyTensor( + handle, srcDesc, srcSuper_dev.get(), dstDesc, dstSuper_dev.get(), srcOffset, dstOffset); + + auto result = handle.Read(dstSuper_dev, dstDataSize); + + operate_over_subtensor([](auto& dst, auto src) { dst = src; }, + dstSuperCpu, + srcSuperCpu, + dstDesc, + srcDesc, + dstOffset, + srcOffset); + + auto mismatch_index = miopen::mismatch_idx(dstSuperCpu, result, miopen::float_equal); + auto mismatch_src_index = mismatch_index - dstOffset + srcOffset; + + ASSERT_EQ(result.size(), mismatch_index) + << "The first mismatched elements are:" // + << " Src[" << mismatch_src_index << "] " << srcSuperCpu[mismatch_src_index] // + << " GPU[" << mismatch_index << "] " << result[mismatch_index] // + << " Ref[" << mismatch_index << "] " << dstSuperCpu[mismatch_index]; // + } + + void TearDown() override {} +}; + +template +tensor GPU_binaryTensorOps::dstSuperTensor; + +template +tensor GPU_binaryTensorOps::srcSuperTensor; +} // namespace + +using float16 = half_float::half; + +#define X_CONCAT_FIRST_SECOND_(first, second) first##second + +#define X_INSTANTIATE_CAST(TEST_TYPE, DST_TYPE, SRC_TYPE, ...) \ + using GPU_binaryTensorOps_cast_##SRC_TYPE##_##TEST_TYPE = \ + GPU_binaryTensorOps; \ + TEST_P(GPU_binaryTensorOps_cast_##SRC_TYPE##_##TEST_TYPE, \ + X_CONCAT_FIRST_SECOND_(__VA_ARGS__, TestTensorCast)) \ + { \ + RunCast(); \ + }; \ + \ + INSTANTIATE_TEST_SUITE_P( \ + Smoke, \ + GPU_binaryTensorOps_cast_##SRC_TYPE##_##TEST_TYPE, \ + testing::Combine(testing::Values(std::vector{32, 8, 10}), \ + testing::Values(std::vector{7, 11}), \ + testing::ValuesIn({1.0f / 127 / 127, 1.0f / 127, 127.0f, 1.0f}), \ + testing::Values(true, false))); \ + \ + INSTANTIATE_TEST_SUITE_P( \ + Full, \ + GPU_binaryTensorOps_cast_##SRC_TYPE##_##TEST_TYPE, \ + testing::Combine(testing::ValuesIn(get_sub_tensor()), \ + testing::ValuesIn(get_tensor_offsets()), \ + testing::ValuesIn({1.0f / 127 / 127, 1.0f / 127, 127.0f, 1.0f}), \ + testing::Values(true, false))); + +X_INSTANTIATE_CAST(FP32, float, float); +X_INSTANTIATE_CAST(FP16, float16, float); +X_INSTANTIATE_CAST(BFP16, bfloat16, float); +X_INSTANTIATE_CAST(I32, int, float); +X_INSTANTIATE_CAST(I8, int8_t, float); + +X_INSTANTIATE_CAST(FP32, float, float16); +X_INSTANTIATE_CAST(FP16, float16, float16); +X_INSTANTIATE_CAST(BFP16, bfloat16, float16, DISABLED_); +X_INSTANTIATE_CAST(I32, int, float16); +X_INSTANTIATE_CAST(I8, int8_t, float16); + +X_INSTANTIATE_CAST(FP32, + float, + bfloat16, + DISABLED_); // bfp16 is just broken except float->bfp16 case +X_INSTANTIATE_CAST(FP16, float16, bfloat16, DISABLED_); +X_INSTANTIATE_CAST(BFP16, bfloat16, bfloat16, DISABLED_); +X_INSTANTIATE_CAST(I32, int, bfloat16, DISABLED_); +X_INSTANTIATE_CAST(I8, int8_t, bfloat16, DISABLED_); + +X_INSTANTIATE_CAST(FP32, float, int); +X_INSTANTIATE_CAST(FP16, float16, int); +X_INSTANTIATE_CAST(BFP16, bfloat16, int, DISABLED_); +X_INSTANTIATE_CAST(I32, int, int); +X_INSTANTIATE_CAST(I8, int8_t, int); + +X_INSTANTIATE_CAST(FP32, float, int8_t); +X_INSTANTIATE_CAST(FP16, float16, int8_t); +X_INSTANTIATE_CAST(BFP16, bfloat16, int8_t, DISABLED_); +X_INSTANTIATE_CAST(I32, int, int8_t); +X_INSTANTIATE_CAST(I8, int8_t, int8_t); + +#undef X_INSTANTIATE_CAST +#undef X_CONCAT_FIRST_SECOND_ + +#define X_INSTANTIATE_COPY(TEST_TYPE, REAL_TYPE) \ + using GPU_binaryTensorOps_copy_##TEST_TYPE = GPU_binaryTensorOps; \ + TEST_P(GPU_binaryTensorOps_copy_##TEST_TYPE, TestTensorCopy) { RunCopy(); }; \ + \ + INSTANTIATE_TEST_SUITE_P(Smoke, \ + GPU_binaryTensorOps_copy_##TEST_TYPE, \ + testing::Combine(testing::Values(std::vector{32, 8, 10}), \ + testing::Values(std::vector{7, 11}), \ + testing::Values(0.0f), \ + testing::Values(false))); \ + \ + INSTANTIATE_TEST_SUITE_P(Full, \ + GPU_binaryTensorOps_copy_##TEST_TYPE, \ + testing::Combine(testing::ValuesIn(get_sub_tensor()), \ + testing::ValuesIn(get_tensor_offsets()), \ + testing::Values(0.0f), \ + testing::Values(false))); + +X_INSTANTIATE_COPY(FP32, float); +X_INSTANTIATE_COPY(FP16, float16); +X_INSTANTIATE_COPY(BFP16, bfloat16); + +#undef X_INSTANTIATE_COPY diff --git a/test/gtest/unary_tensor_ops.cpp b/test/gtest/unary_tensor_ops.cpp new file mode 100644 index 0000000000..7ee680aef2 --- /dev/null +++ b/test/gtest/unary_tensor_ops.cpp @@ -0,0 +1,143 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include +#include + +#include "get_handle.hpp" +#include "tensor_holder.hpp" +#include "tensor_util.hpp" +#include "verify.hpp" + +namespace { +using UnaryTensorOpsCase = std::tuple, int>; + +template +class GPU_unaryTensorOps : public ::testing::TestWithParam +{ +public: + static tensor superTensor; + + static void SetUpTestSuite() + { + uint64_t max_value = miopen_type{} == miopenHalf ? 5 : 17; + superTensor = tensor{std::vector{32, 32, 16, 16, 16}}.generate( + tensor_elem_gen_integer{max_value}); + } + +protected: + const T alpha = static_cast(2.048); + size_t dataSize; + miopen::TensorDescriptor subDesc; + + void SetUp() override + { + const auto& [lens, offset] = GetParam(); + ASSERT_GE(superTensor.desc.GetNumDims(), lens.size()); + + const std::vector& superStrides = superTensor.desc.GetStrides(); + std::vector strides(superStrides.begin() + + (superTensor.desc.GetNumDims() - lens.size()), + superStrides.end()); + + subDesc = miopen::TensorDescriptor(miopen_type{}, lens, strides); + dataSize = subDesc.GetElementSpace() + offset; + ASSERT_GE(superTensor.desc.GetElementSpace(), dataSize); + } + + template + void Run(DataOp&& dataOp, GpuOp&& gpuOp) + { + std::vector superCpu(superTensor.begin(), superTensor.begin() + dataSize); + auto offset = std::get(GetParam()); + + auto&& handle = get_handle(); + auto super_dev = handle.Write(superCpu); + gpuOp(handle, subDesc, super_dev.get(), &alpha, offset); + auto result = handle.Read(super_dev, dataSize); + + operate_over_subtensor(dataOp, superCpu, subDesc, offset); + + auto mismatch_index = miopen::mismatch_idx(superCpu, result, miopen::float_equal); + + ASSERT_EQ(result.size(), mismatch_index) + << "The first mismatched elements are:" // + << " GPU[" << mismatch_index << "] " << result[mismatch_index] // + << " Ref[" << mismatch_index << "] " << superCpu[mismatch_index]; // + } + + void RunScale() + { + Run([a = alpha](auto& val) { val *= a; }, + [](auto&&... params) { miopen::ScaleTensor(params...); }); + } + + void RunSet() + { + Run([a = alpha](auto& val) { val = a; }, + [](auto&&... params) { miopen::SetTensor(params...); }); + } + + void TearDown() override {} +}; + +template +tensor GPU_unaryTensorOps::superTensor; + +} // namespace + +using float16 = half_float::half; + +#define X_CONCAT_FIRST_SECOND_(first, second) first##second + +#define X_INSTANTIATE(TEST_TYPE, REAL_TYPE, ...) \ + using GPU_unaryTensorOps_##TEST_TYPE = GPU_unaryTensorOps; \ + TEST_P(GPU_unaryTensorOps_##TEST_TYPE, X_CONCAT_FIRST_SECOND_(__VA_ARGS__, TestTensorScale)) \ + { \ + RunScale(); \ + }; \ + TEST_P(GPU_unaryTensorOps_##TEST_TYPE, TestTensorSet) { RunSet(); }; \ + \ + INSTANTIATE_TEST_SUITE_P( \ + Smoke, \ + GPU_unaryTensorOps_##TEST_TYPE, \ + testing::Combine(testing::Values(std::vector{32, 8, 10}), testing::Values(7))); \ + \ + INSTANTIATE_TEST_SUITE_P(Full, \ + GPU_unaryTensorOps_##TEST_TYPE, \ + testing::Combine(testing::ValuesIn(get_sub_tensor()), \ + testing::ValuesIn(get_tensor_offset()))); + +X_INSTANTIATE(FP32, float); +X_INSTANTIATE(FP16, float16); +X_INSTANTIATE(I32, int); +X_INSTANTIATE(I8, int8_t, DISABLED_); // disable Scale for int8 +X_INSTANTIATE(BFP16, bfloat16, DISABLED_); // disable Scale for bfloat16 + +#undef X_INSTANTIATE +#undef X_CONCAT_FIRST_SECOND_ diff --git a/test/network_data.hpp b/test/network_data.hpp index e0a39323e3..09ae13762f 100644 --- a/test/network_data.hpp +++ b/test/network_data.hpp @@ -30,26 +30,26 @@ #include #include #include +#include #ifndef MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR #define MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR 0 #endif -inline int pick_batch_size(int x, int y) +template +inline constexpr T pick_batch_size(T x, T y) { - if(y == 0 || y > x) - return 1; - else - return x / y; + return (y == 0 || y > x) ? 1 : x / y; } // Reduce tests execution time #define MIOPEN_TESTS_GET_INPUTS_ENABLE_HUGE_TENSORS 1 -inline std::set> get_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) +template +inline std::set> get_inputs(T n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { // clang-format off - return + return { { pick_batch_size(32, n), 1, 14, 14 }, { pick_batch_size(100, n), 1, 8, 8 }, @@ -103,10 +103,11 @@ inline std::set> get_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_S // clang-format on } -inline std::set> get_weights(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) +template +inline std::set> get_weights(T n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { // clang-format off - return + return { { pick_batch_size(1024, n),1024, 3, 3 }, { pick_batch_size(1024, n),512, 3, 3 }, @@ -139,10 +140,11 @@ inline std::set> get_weights(int n = MIOPEN_TEST_DEFAULT_BATCH_ // clang-format on } -inline std::set> get_immed_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) +template +inline std::set> get_immed_inputs(T n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { // clang-format off - return + return { { pick_batch_size(32, n), 1, 14, 14 }, { pick_batch_size(256, n), 1, 27, 27 }, @@ -160,10 +162,11 @@ inline std::set> get_immed_inputs(int n = MIOPEN_TEST_DEFAULT_B // clang-format on } -inline std::set> get_immed_weights(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) +template +inline std::set> get_immed_weights(T n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { // clang-format off - return + return { { pick_batch_size(208, n), 96, 3, 3 }, { pick_batch_size(24, n), 512, 1, 1 }, @@ -182,11 +185,12 @@ inline std::set> get_immed_weights(int n = MIOPEN_TEST_DEFAULT_ // clang-format on } -inline std::set> -get_3d_conv_input_shapes(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) +template +inline std::set> +get_3d_conv_input_shapes(T n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { // clang-format off - return + return { { pick_batch_size(128, n), 1, 1, 2, 2}, { pick_batch_size(128, n), 64, 1, 1, 1}, @@ -201,11 +205,12 @@ get_3d_conv_input_shapes(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) // clang-format on } -inline std::set> -get_3d_conv_weight_shapes(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) +template +inline std::set> +get_3d_conv_weight_shapes(T n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { // clang-format off - return + return { { pick_batch_size( 128, n), 1, 1, 1, 1}, { pick_batch_size( 352, n), 128, 1, 1, 1}, @@ -222,11 +227,11 @@ get_3d_conv_weight_shapes(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) // clang-format on } -inline std::set> -get_bn_peract_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) +template +inline std::set> get_bn_peract_inputs(T n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { // clang-format off - return + return { { pick_batch_size(32, n), 4, 1024,2048}, //Making this much smaller { pick_batch_size(100, n), 3, 32, 32 }, @@ -268,11 +273,11 @@ get_bn_peract_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) // clang-format on } -inline std::set> -get_bn_spatial_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) +template +inline std::set> get_bn_spatial_inputs(T n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { // clang-format off - return + return { { pick_batch_size(32, n), 4, 1024,2048}, //Making this much smaller { pick_batch_size(32, n), 192, 256, 512 }, @@ -322,11 +327,11 @@ get_bn_spatial_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) // clang-format on } -inline std::set> -get_3d_bn_peract_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) +template +inline std::set> get_3d_bn_peract_inputs(T n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { // clang-format off - return + return { { pick_batch_size(32, n), 1, 32, 32, 32 }, // 32x32x32 based on VoxNet arch { pick_batch_size(32, n), 1, 14, 14, 14 }, @@ -336,20 +341,20 @@ get_3d_bn_peract_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { pick_batch_size(256, n), 1, 32, 32, 32 }, // 32x32x32 based on VoxNet arch { pick_batch_size(256, n), 32, 14, 14, 14 }, { pick_batch_size(256, n), 32, 12, 12, 12 }, - { pick_batch_size(256, n), 32, 6, 6, 6 }, + { pick_batch_size(256, n), 32, 6, 6, 6 }, { pick_batch_size(512, n), 1, 32, 32, 32 }, // 32x32x32 based on VoxNet arch { pick_batch_size(512, n), 32, 14, 14, 14 }, { pick_batch_size(512, n), 32, 12, 12, 12 }, - { pick_batch_size(512, n), 32, 6, 6, 6 }, + { pick_batch_size(512, n), 32, 6, 6, 6 }, { pick_batch_size(32, n), 2, 32, 57, 125 }, // Hand-gesture recognition CVPR 2015 paper High Res Net Path { pick_batch_size(32, n), 32, 14, 25, 59 }, { pick_batch_size(32, n), 32, 6, 10, 27 }, - { pick_batch_size(32, n), 32, 4, 6, 11 }, - { pick_batch_size(32, n), 32, 2, 2, 3 }, - { pick_batch_size(32, n), 32, 32, 28, 62 }, // Hand-gesture recognition CVPR 2015 paper Low Res Net Path + { pick_batch_size(32, n), 32, 4, 6, 11 }, + { pick_batch_size(32, n), 32, 2, 2, 3 }, + { pick_batch_size(32, n), 32, 32, 28, 62 }, // Hand-gesture recognition CVPR 2015 paper Low Res Net Path { pick_batch_size(32, n), 32, 14, 12, 29 }, - { pick_batch_size(32, n), 32, 6, 4, 12 }, - { pick_batch_size(32, n), 32, 4, 2, 2 }, + { pick_batch_size(32, n), 32, 6, 4, 12 }, + { pick_batch_size(32, n), 32, 4, 2, 2 }, { pick_batch_size(16, n), 32, 6, 50, 50 }, // Multi-view 3D convnet { pick_batch_size(1, n), 3, 8, 240, 320 }, // 3D convet on video { pick_batch_size(1, n), 3, 16, 240, 320 }, // 3D convet on video @@ -362,11 +367,12 @@ get_3d_bn_peract_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) // clang-format on } -inline std::set> -get_3d_bn_spatial_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) +template +inline std::set> +get_3d_bn_spatial_inputs(T n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { // clang-format off - return + return { { pick_batch_size(32, n), 1, 32, 32, 32 }, // 32x32x32 based on VoxNet arch { pick_batch_size(32, n), 1, 14, 14, 14 }, @@ -376,20 +382,20 @@ get_3d_bn_spatial_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) { pick_batch_size(256, n), 1, 32, 32, 32 }, // 32x32x32 based on VoxNet arch { pick_batch_size(256, n), 32, 14, 14, 14 }, { pick_batch_size(256, n), 32, 12, 12, 12 }, - { pick_batch_size(256, n), 32, 6, 6, 6 }, + { pick_batch_size(256, n), 32, 6, 6, 6 }, { pick_batch_size(512, n), 1, 32, 32, 32 }, // 32x32x32 based on VoxNet arch { pick_batch_size(512, n), 32, 14, 14, 14 }, { pick_batch_size(512, n), 32, 12, 12, 12 }, - { pick_batch_size(512, n), 32, 6, 6, 6 }, + { pick_batch_size(512, n), 32, 6, 6, 6 }, { pick_batch_size(32, n), 2, 32, 57, 125 }, // Hand-gesture recognition CVPR 2015 paper High Res Net Path { pick_batch_size(32, n), 32, 14, 25, 59 }, { pick_batch_size(32, n), 32, 6, 10, 27 }, - { pick_batch_size(32, n), 32, 4, 6, 11 }, - { pick_batch_size(32, n), 32, 2, 2, 3 }, - { pick_batch_size(32, n), 32, 32, 28, 62 }, // Hand-gesture recognition CVPR 2015 paper Low Res Net Path + { pick_batch_size(32, n), 32, 4, 6, 11 }, + { pick_batch_size(32, n), 32, 2, 2, 3 }, + { pick_batch_size(32, n), 32, 32, 28, 62 }, // Hand-gesture recognition CVPR 2015 paper Low Res Net Path { pick_batch_size(32, n), 32, 14, 12, 29 }, - { pick_batch_size(32, n), 32, 6, 4, 12 }, - { pick_batch_size(32, n), 32, 4, 2, 2 }, + { pick_batch_size(32, n), 32, 6, 4, 12 }, + { pick_batch_size(32, n), 32, 4, 2, 2 }, { pick_batch_size(16, n), 32, 6, 50, 50 }, // Multi-view 3D convnet { pick_batch_size(1, n), 3, 8, 240, 320 }, // 3D convet on video { pick_batch_size(1, n), 3, 16, 240, 320 }, // 3D convet on video @@ -401,7 +407,8 @@ get_3d_bn_spatial_inputs(int n = MIOPEN_TEST_DEFAULT_BATCH_SIZE_FACTOR) // clang-format on } -inline std::vector> get_sub_tensor() +template +inline std::vector> get_sub_tensor() { return {{16, 4, 8, 1, 4}, {2, 4, 8, 8, 4}, @@ -414,11 +421,18 @@ inline std::vector> get_sub_tensor() {4}}; } -inline std::vector> get_tensor_offsets() +template +inline std::vector> get_tensor_offsets() { + static_assert(std::is_signed_v); return {{0, 0}, {0, 2}, {4, 0}, {5, 7}}; } -inline std::vector get_tensor_offset() { return {0, 1, 2, 3, 4, 5}; } +template +inline std::vector get_tensor_offset() +{ + static_assert(std::is_signed_v); + return {0, 1, 2, 3, 4, 5}; +} #endif diff --git a/test/tensor_cast.cpp b/test/tensor_cast.cpp deleted file mode 100644 index a170d9da12..0000000000 --- a/test/tensor_cast.cpp +++ /dev/null @@ -1,204 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2018 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "test.hpp" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "driver.hpp" -#include "get_handle.hpp" -#include "tensor_holder.hpp" -#include "verify.hpp" - -template -struct verify_tensor_cast -{ - miopen::TensorDescriptor srcDesc; - miopen::TensorDescriptor dstDesc; - tensor srcSuper; - tensor dstSuper; - int srcOffset; - int dstOffset; - float alpha; - float max_val; - - verify_tensor_cast(const tensor& psrc_super, - const tensor& pdst_super, - const miopen::TensorDescriptor& psd, - const miopen::TensorDescriptor& pdd, - std::vector offsets, - const float palpha, - const float pmax_val) - { - srcDesc = psd; - dstDesc = pdd; - srcSuper = psrc_super; - dstSuper = pdst_super; - srcOffset = offsets[0]; - dstOffset = offsets[1]; - alpha = palpha; - max_val = pmax_val; - } - - void tensor_cast_for_loop(tensor& dstSuperCpu, - int src_offset_index, - int dst_offset_index, - int dim) const - { - auto src_stride = srcDesc.GetStrides()[dim]; - auto dst_stride = dstDesc.GetStrides()[dim]; - - for(int idx = 0; idx < srcDesc.GetLengths()[dim]; idx++) - { - std::size_t src_super_index = - ((dim == 0) ? srcOffset : 0) + src_offset_index + src_stride * idx; - std::size_t dst_super_index = - ((dim == 0) ? dstOffset : 0) + dst_offset_index + dst_stride * idx; - - if(dim < (srcDesc.GetLengths().size() - 1)) - { - tensor_cast_for_loop(dstSuperCpu, src_super_index, dst_super_index, dim + 1); - } - if(dst_super_index < dstSuperCpu.desc.GetElementSpace() && - src_super_index < srcSuper.desc.GetElementSpace()) - { - float temp_val = float(srcSuper[src_super_index]) * alpha; - dstSuperCpu[dst_super_index] = T(temp_val >= max_val ? max_val : temp_val); - } - } - } - - tensor cpu() const - { - tensor dstSuperCpu = dstSuper; - - tensor_cast_for_loop(dstSuperCpu, 0, 0, 0); - - return dstSuperCpu; - } - - tensor gpu() const - { - tensor dstSuperGpu = dstSuper; - - auto&& handle = get_handle(); - auto dstSuper_dev = handle.Write(dstSuperGpu.data); - auto srcSuper_dev = handle.Write(srcSuper.data); - - miopen::CastTensor(handle, - &alpha, - true, - srcDesc, - srcSuper_dev.get(), - dstDesc, - dstSuper_dev.get(), - srcOffset, - dstOffset); - - dstSuperGpu.data = handle.Read(dstSuper_dev, dstSuperGpu.data.size()); - - return dstSuperGpu; - } - - void fail(float = 0) - { - std::cout << "Tensor Cast: " << std::endl; - std::cout << "src super-tensor: " << srcSuper.desc.ToString() << std::endl; - std::cout << "dst super-tensor: " << dstSuper.desc.ToString() << std::endl; - std::cout << "src sub-tensor: " << srcDesc.ToString() << std::endl; - std::cout << "dst sub-tensor: " << dstDesc.ToString() << std::endl; - } -}; - -template -struct tensor_cast_driver : test_driver -{ - tensor srcSuper; - tensor dstSuper; - std::vector srcSuperLens; - std::vector dstSuperLens; - float alpha = 1.0; - float max_val = 0.; - - miopen::TensorDescriptor srcDesc; - miopen::TensorDescriptor dstDesc; - std::vector castLens; - std::vector offsets; - - tensor_cast_driver() - { - disabled_cache = true; - std::vector src_lens = {32, 16, 32, 16, 16}; - std::vector dst_lens = {32, 32, 16, 16, 16}; - - add(srcSuperLens, "srcSuperLens", generate_data({src_lens}, src_lens)); - add(dstSuperLens, "dstSuperLens", generate_data({dst_lens}, dst_lens)); - add(castLens, "castLens", generate_data(get_sub_tensor(), {32, 8, 10})); - add(offsets, "offsets", generate_data(get_tensor_offsets(), {7, 11})); - add(alpha, "alpha", generate_data({1.0 / 127 / 127, 1.0 / 127, 127.0, 1.0})); - } - - void run() - { - uint64_t max_value = - miopen_type{} == miopenHalf ? 5 : (miopen_type{} == miopenInt8 ? 126 : 32767); - max_val = miopen_type{} == miopenHalf ? 65504.0 - : miopen_type{} == miopenInt8 ? 127.0 - : miopen_type{} == miopenInt32 ? 2147483647.0 - : miopen_type{} == miopenBFloat16 ? 0x7F7F - : 3.402823466e+38F; - - srcSuper = tensor{srcSuperLens}.generate(tensor_elem_gen_integer{max_value}); - dstSuper = tensor{dstSuperLens}.generate(tensor_elem_gen_integer{max_value}); - - std::vector srcSuperStrides = srcSuper.desc.GetStrides(); - std::vector dstSuperStrides = dstSuper.desc.GetStrides(); - std::vector src_super_strides(srcSuperStrides.begin() + - (srcSuper.desc.GetNumDims() - castLens.size()), - srcSuperStrides.end()); - std::vector dst_super_strides(dstSuperStrides.begin() + - (dstSuper.desc.GetNumDims() - castLens.size()), - dstSuperStrides.end()); - - srcDesc = miopen::TensorDescriptor(miopenInt32, castLens, src_super_strides); - dstDesc = miopen::TensorDescriptor(miopen_type{}, castLens, dst_super_strides); - - if(srcDesc.GetLengths().size() == dstDesc.GetLengths().size()) - { - verify_equals(verify_tensor_cast{ - srcSuper, dstSuper, srcDesc, dstDesc, offsets, alpha, max_val}); - } - } -}; - -int main(int argc, const char* argv[]) { test_drive(argc, argv); } diff --git a/test/tensor_copy.cpp b/test/tensor_copy.cpp deleted file mode 100644 index 9cdf762b52..0000000000 --- a/test/tensor_copy.cpp +++ /dev/null @@ -1,182 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2017 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "test.hpp" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "driver.hpp" -#include "get_handle.hpp" -#include "tensor_holder.hpp" -#include "verify.hpp" - -template -struct verify_tensor_copy -{ - miopen::TensorDescriptor srcDesc; - miopen::TensorDescriptor dstDesc; - tensor srcSuper; - tensor dstSuper; - int srcOffset; - int dstOffset; - - verify_tensor_copy(const tensor& psrc_super, - const tensor& pdst_super, - const miopen::TensorDescriptor& psd, - const miopen::TensorDescriptor& pdd, - std::vector offsets) - { - srcDesc = psd; - dstDesc = pdd; - srcSuper = psrc_super; - dstSuper = pdst_super; - srcOffset = offsets[0]; - dstOffset = offsets[1]; - } - - void tensor_copy_for_loop(tensor& dstSuperCpu, - int src_offset_index, - int dst_offset_index, - int dim) const - { - auto src_stride = srcDesc.GetStrides()[dim]; - auto dst_stride = dstDesc.GetStrides()[dim]; - - for(int idx = 0; idx < srcDesc.GetLengths()[dim]; idx++) - { - std::size_t src_super_index = - ((dim == 0) ? srcOffset : 0) + src_offset_index + src_stride * idx; - std::size_t dst_super_index = - ((dim == 0) ? dstOffset : 0) + dst_offset_index + dst_stride * idx; - - if(dim < (srcDesc.GetLengths().size() - 1)) - { - tensor_copy_for_loop(dstSuperCpu, src_super_index, dst_super_index, dim + 1); - } - if(dst_super_index < dstSuperCpu.desc.GetElementSpace() && - src_super_index < srcSuper.desc.GetElementSpace()) - { - dstSuperCpu[dst_super_index] = srcSuper[src_super_index]; - } - } - } - - tensor cpu() const - { - tensor dstSuperCpu = dstSuper; - - tensor_copy_for_loop(dstSuperCpu, 0, 0, 0); - - return dstSuperCpu; - } - - tensor gpu() const - { - tensor dstSuperGpu = dstSuper; - - auto&& handle = get_handle(); - auto dstSuper_dev = handle.Write(dstSuperGpu.data); - auto srcSuper_dev = handle.Write(srcSuper.data); - - miopen::CopyTensor( - handle, srcDesc, srcSuper_dev.get(), dstDesc, dstSuper_dev.get(), srcOffset, dstOffset); - - dstSuperGpu.data = handle.Read(dstSuper_dev, dstSuperGpu.data.size()); - - return dstSuperGpu; - } - - void fail(float = 0) - { - std::cout << "Tensor Copy: " << std::endl; - std::cout << "src super-tensor: " << srcSuper.desc.ToString() << std::endl; - std::cout << "dst super-tensor: " << dstSuper.desc.ToString() << std::endl; - std::cout << "src sub-tensor: " << srcDesc.ToString() << std::endl; - std::cout << "dst sub-tensor: " << dstDesc.ToString() << std::endl; - } -}; - -template -struct tensor_copy_driver : test_driver -{ - tensor srcSuper; - tensor dstSuper; - std::vector srcSuperLens; - std::vector dstSuperLens; - - miopen::TensorDescriptor srcDesc; - miopen::TensorDescriptor dstDesc; - std::vector copyLens; - std::vector offsets; - - tensor_copy_driver() - { - disabled_cache = true; - std::vector src_lens = {32, 16, 32, 16, 16}; - std::vector dst_lens = {32, 32, 16, 16, 16}; - - add(srcSuperLens, "srcSuperLens", generate_data({src_lens}, src_lens)); - add(dstSuperLens, "dstSuperLens", generate_data({dst_lens}, dst_lens)); - add(copyLens, "copyLens", generate_data(get_sub_tensor(), {32, 8, 10})); - add(offsets, "offsets", generate_data(get_tensor_offsets(), {7, 11})); - } - - void run() - { - uint64_t max_value = miopen_type{} == miopenHalf ? 5 - : miopen_type{} == miopenInt8 ? 127 - : 17; - - srcSuper = tensor{srcSuperLens}.generate(tensor_elem_gen_integer{max_value}); - dstSuper = tensor{dstSuperLens}.generate(tensor_elem_gen_integer{max_value}); - - std::vector srcSuperStrides = srcSuper.desc.GetStrides(); - std::vector dstSuperStrides = dstSuper.desc.GetStrides(); - std::vector src_super_strides(srcSuperStrides.begin() + - (srcSuper.desc.GetNumDims() - copyLens.size()), - srcSuperStrides.end()); - std::vector dst_super_strides(dstSuperStrides.begin() + - (dstSuper.desc.GetNumDims() - copyLens.size()), - dstSuperStrides.end()); - - srcDesc = miopen::TensorDescriptor(this->type, copyLens, src_super_strides); - dstDesc = miopen::TensorDescriptor(this->type, copyLens, dst_super_strides); - - if(srcDesc.GetLengths().size() == dstDesc.GetLengths().size()) - { - verify_equals(verify_tensor_copy{srcSuper, dstSuper, srcDesc, dstDesc, offsets}); - } - } -}; - -int main(int argc, const char* argv[]) { test_drive(argc, argv); } diff --git a/test/tensor_holder.hpp b/test/tensor_holder.hpp index e1b03880b8..ff9566fe6c 100644 --- a/test/tensor_holder.hpp +++ b/test/tensor_holder.hpp @@ -145,6 +145,7 @@ struct miopen_type : std::integral_constant struct tensor { + using value_type = T; miopen::TensorDescriptor desc; std::vector data; diff --git a/test/tensor_scale.cpp b/test/tensor_scale.cpp deleted file mode 100644 index ff05bd9d1a..0000000000 --- a/test/tensor_scale.cpp +++ /dev/null @@ -1,139 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2017 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "test.hpp" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "driver.hpp" -#include "get_handle.hpp" -#include "tensor_holder.hpp" -#include "verify.hpp" -#include "tensor_util.hpp" - -template -struct scale_data_t -{ - const T alpha; - - void operator()(T& r_data) const { r_data *= alpha; } -}; - -template -struct verify_tensor_scale -{ - miopen::TensorDescriptor subDesc; - tensor super; - int offset; - T alpha; - - verify_tensor_scale(const tensor& rSuper, - const miopen::TensorDescriptor& rSubDesc, - const int offsetIn, - const T alphaIn) - { - subDesc = rSubDesc; - super = rSuper; - offset = offsetIn; - alpha = alphaIn; - } - - tensor cpu() const - { - tensor superCpu = super; - - const scale_data_t data_operator = {alpha}; - - operate_over_subtensor(data_operator, superCpu, subDesc, offset); - - return superCpu; - } - - tensor gpu() const - { - tensor superGpu = super; - - auto&& handle = get_handle(); - auto super_dev = handle.Write(superGpu.data); - - miopen::ScaleTensor(handle, subDesc, super_dev.get(), &alpha, offset); - - superGpu.data = handle.Read(super_dev, superGpu.data.size()); - - return superGpu; - } - - void fail(float = 0) - { - std::cout << "Tensor Set: " << std::endl; - std::cout << "super-tensor: " << super.desc.ToString() << std::endl; - std::cout << "sub-tensor: " << subDesc.ToString() << std::endl; - } -}; - -template -struct tensor_scale_driver : test_driver -{ - tensor super; - std::vector superLens; - miopen::TensorDescriptor subDesc; - std::vector subLens; - int offset = 0; - - tensor_scale_driver() - { - disabled_cache = true; - std::vector lens = {32, 32, 16, 16, 16}; - - add(superLens, "superLens", generate_data({lens}, lens)); - add(subLens, "subLens", generate_data(get_sub_tensor(), {32, 8, 10})); - add(offset, "offset", generate_data(get_tensor_offset(), 7)); - } - - void run() - { - uint64_t max_value = miopen_type{} == miopenHalf ? 5 : 17; - - super = tensor{superLens}.generate(tensor_elem_gen_integer{max_value}); - - std::vector superStrides = super.desc.GetStrides(); - std::vector subStrides( - superStrides.begin() + (super.desc.GetNumDims() - subLens.size()), superStrides.end()); - - subDesc = miopen::TensorDescriptor(this->type, subLens, subStrides); - - verify_equals(verify_tensor_scale{super, subDesc, offset, T(2.048)}); - } -}; - -int main(int argc, const char* argv[]) { test_drive(argc, argv); } diff --git a/test/tensor_set.cpp b/test/tensor_set.cpp deleted file mode 100644 index 5128617338..0000000000 --- a/test/tensor_set.cpp +++ /dev/null @@ -1,141 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2017 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#include "test.hpp" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "driver.hpp" -#include "get_handle.hpp" -#include "tensor_holder.hpp" -#include "verify.hpp" -#include "tensor_util.hpp" - -template -struct set_data_t -{ - const T alpha; - - void operator()(T& r_data) const { r_data = alpha; } -}; - -template -struct verify_tensor_set -{ - miopen::TensorDescriptor subDesc; - tensor super; - int offset; - T alpha; - - verify_tensor_set(const tensor& rSuper, - const miopen::TensorDescriptor& rSubDesc, - const int offsetIn, - const T alphaIn) - { - subDesc = rSubDesc; - super = rSuper; - offset = offsetIn; - alpha = alphaIn; - } - - tensor cpu() const - { - tensor superCpu = super; - - const set_data_t data_operator = {alpha}; - - operate_over_subtensor(data_operator, superCpu, subDesc, offset); - - return superCpu; - } - - tensor gpu() const - { - tensor superGpu = super; - - auto&& handle = get_handle(); - auto super_dev = handle.Write(superGpu.data); - - miopen::SetTensor(handle, subDesc, super_dev.get(), &alpha, offset); - - superGpu.data = handle.Read(super_dev, superGpu.data.size()); - - return superGpu; - } - - void fail(float = 0) - { - std::cout << "Tensor Set: " << std::endl; - std::cout << "super-tensor: " << super.desc.ToString() << std::endl; - std::cout << "sub-tensor: " << subDesc.ToString() << std::endl; - } -}; - -template -struct tensor_set_driver : test_driver -{ - tensor super; - std::vector superLens; - miopen::TensorDescriptor subDesc; - std::vector subLens; - int offset = 0; - - tensor_set_driver() - { - disabled_cache = true; - std::vector lens = {32, 32, 16, 16, 16}; - - add(superLens, "superLens", generate_data({lens}, lens)); - add(subLens, "subLens", generate_data(get_sub_tensor(), {32, 8, 10})); - add(offset, "offset", generate_data(get_tensor_offset(), 7)); - } - - void run() - { - uint64_t max_value = miopen_type{} == miopenHalf ? 5 - : miopen_type{} == miopenInt8 ? 127 - : 17; - - super = tensor{superLens}.generate(tensor_elem_gen_integer{max_value}); - - std::vector superStrides = super.desc.GetStrides(); - std::vector subStrides( - superStrides.begin() + (super.desc.GetNumDims() - subLens.size()), superStrides.end()); - - subDesc = miopen::TensorDescriptor(this->type, subLens, subStrides); - - verify_equals(verify_tensor_set{super, subDesc, offset, T(1.111)}); - } -}; - -int main(int argc, const char* argv[]) { test_drive(argc, argv); } diff --git a/test/tensor_transform.cpp b/test/tensor_transform.cpp index 90c101b9b0..c67491b5e6 100644 --- a/test/tensor_transform.cpp +++ b/test/tensor_transform.cpp @@ -229,58 +229,22 @@ struct verify_tensor_transform_scale beta = betaIn; } - static T multadd_elem(T aelem, T acte, T belem, T bcte) - { - return ((acte * aelem) + (bcte * belem)); - } - void tensor_multadd_for_loop(tensor& superCpu_src, - tensor& superCpu_dst, - int src_offset_index, - int dst_offset_index, - T acte, - T bcte, - int dim) const - { - auto src_stride = subDesc_src.GetStrides()[dim]; - auto dst_stride = subDesc_dst.GetStrides()[dim]; - size_t srcOffset = src_offset; - size_t dstOffset = dst_offset; - - for(int idx = 0; idx < subDesc_src.GetLengths()[dim]; idx++) - { - std::size_t src_super_index = - ((dim == 0) ? srcOffset : 0) + src_offset_index + src_stride * idx; - std::size_t dst_super_index = - ((dim == 0) ? dstOffset : 0) + dst_offset_index + dst_stride * idx; - - if(dim < (subDesc_src.GetLengths().size() - 1)) - { - tensor_multadd_for_loop(superCpu_src, - superCpu_dst, - src_super_index, - dst_super_index, - acte, - bcte, - dim + 1); - } - else if(dst_super_index < superCpu_dst.desc.GetElementSpace() && - src_super_index < superCpu_src.desc.GetElementSpace()) - { - superCpu_dst[dst_super_index] = multadd_elem(T(superCpu_src[src_super_index]), - alpha, - T(superCpu_dst[dst_super_index]), - beta); - } - } - } - tensor cpu() const { tensor superCpu_src = super_src; tensor superCpu_dst = super_dst; - tensor_multadd_for_loop(superCpu_src, superCpu_dst, 0, 0, alpha, beta, 0); + operate_over_subtensor( + [a = alpha, b = beta](auto& dst, auto src) { + dst = static_cast(src) * a + static_cast(dst) * b; + }, + superCpu_dst, + superCpu_src, + subDesc_dst, + subDesc_src, + dst_offset, + src_offset); #if(MIO_TRANSFORM_DEBUG) printf("\n CPU: \n"); diff --git a/test/tensor_util.hpp b/test/tensor_util.hpp index d14f0c806d..244a7f4952 100644 --- a/test/tensor_util.hpp +++ b/test/tensor_util.hpp @@ -27,6 +27,8 @@ #ifndef GUARD_TENSOR_UTIL_HPP #define GUARD_TENSOR_UTIL_HPP +#include + #include #include #include @@ -34,44 +36,133 @@ namespace fs = miopen::fs; -// loop over sub-tensor, and operate on each data -template class data_operator_t> -void operate_over_subtensor(const data_operator_t& r_data_operator, - tensor& rSuperTensor, - const miopen::TensorDescriptor& rSubDesc, - const int offset) +// unary operation +template +void operate_over_subtensor(DataOp&& dataOp, + Container& srcSuperTensor, + const miopen::TensorDescriptor& srcSubDesc, + const int64_t srcOffset) { - operate_over_subtensor_impl(r_data_operator, rSuperTensor, rSubDesc, 0, offset); + const auto& srcStrides = srcSubDesc.GetStrides(); + const auto& srcLens = srcSubDesc.GetLengths(); + + auto operate_over_subtensor_impl = + [&, dataOp, max_dim = srcLens.size() - 1]( + auto&& self, const size_t current_dim, const int64_t srcOff) -> void { + const auto current_stride = srcStrides[current_dim]; + + int64_t index = srcOff; + + for(size_t i = 0; i < srcLens[current_dim]; ++i) + { + if(current_dim < max_dim) + { + self(self, current_dim + 1, index); + } + else + { + dataOp(srcSuperTensor[index]); + } + index += current_stride; + } + }; + operate_over_subtensor_impl(operate_over_subtensor_impl, 0, srcOffset); } -// loop over part of sub-tensor (dimensions lower than "current_dim"), and operate on -// each data -template class data_operator_t> -void operate_over_subtensor_impl(const data_operator_t& r_data_operator, - tensor& rSuperTensor, - const miopen::TensorDescriptor& rSubDesc, - const unsigned current_dim, - const int offset) +// binary operation, it implies cast operation +template +void operate_over_subtensor(DataOp&& dataOp, + DstContainer& dstSuperTensor, + SrcContainer& srcSuperTensor, + const miopen::TensorDescriptor& dstSubDesc, + const miopen::TensorDescriptor& srcSubDesc, + const int64_t dstOffset, + const int64_t srcOffset) { - auto max_dim = static_cast(rSubDesc.GetLengths().size() - 1); - auto current_stride = static_cast(rSubDesc.GetStrides()[current_dim]); + const auto& dstStrides = dstSubDesc.GetStrides(); + const auto& srcStrides = srcSubDesc.GetStrides(); - int index = offset; + const auto& srcLens = srcSubDesc.GetLengths(); - for(int i = 0; i < rSubDesc.GetLengths()[current_dim]; ++i) - { - if(current_dim == max_dim) + auto operate_over_subtensor_impl = + [&, dataOp, max_dim = srcLens.size() - 1](auto&& self, + const size_t current_dim, + const int64_t dstOff, + const int64_t srcOff) -> void { + const auto dstStride = dstStrides[current_dim]; + const auto srcStride = srcStrides[current_dim]; + + int64_t dstIdx = dstOff; + int64_t srcIdx = srcOff; + + for(size_t i = 0; i < srcLens[current_dim]; ++i) { - r_data_operator(rSuperTensor[index]); + if(current_dim < max_dim) + { + self(self, current_dim + 1, dstIdx, srcIdx); + } + else + { + dataOp(dstSuperTensor[dstIdx], srcSuperTensor[srcIdx]); + } + dstIdx += dstStride; + srcIdx += srcStride; } - else + }; + operate_over_subtensor_impl(operate_over_subtensor_impl, 0, dstOffset, srcOffset); +} + +// ternary operation, it implies broadcasting for src2 +template +void operate_over_subtensor(DataOp&& dataOp, + Container& dstSuperTensor, + const Container& src1SuperTensor, + const Container& src2SuperTensor, + const miopen::TensorDescriptor& dstSubDesc, + const miopen::TensorDescriptor& src1SubDesc, + const miopen::TensorDescriptor& src2SubDesc, + const int64_t dstOffset, + const int64_t src1Offset, + const int64_t src2Offset) +{ + const auto& dstStrides = dstSubDesc.GetStrides(); + const auto& src1Strides = src1SubDesc.GetStrides(); + const auto& src2Strides = src2SubDesc.GetStrides(); + + const auto& src1Lens = src1SubDesc.GetLengths(); + const auto& src2Lens = src2SubDesc.GetLengths(); + + auto operate_over_subtensor_impl = + [&, dataOp, max_dim = src1Lens.size() - 1](auto&& self, + const size_t current_dim, + const int64_t dstOff, + const int64_t src1Off, + const int64_t src2Off) -> void { + const auto dstStride = dstStrides[current_dim]; + const auto src1Stride = src1Strides[current_dim]; + const auto src2Stride = src2Strides[current_dim]; + const bool squashed = src1Lens[current_dim] != src2Lens[current_dim]; + + int64_t dstIdx = dstOff; + int64_t src1Idx = src1Off; + int64_t src2Idx = src2Off; + + for(size_t i = 0; i < src1Lens[current_dim]; ++i) { - operate_over_subtensor_impl( - r_data_operator, rSuperTensor, rSubDesc, current_dim + 1, index); + if(current_dim < max_dim) + { + self(self, current_dim + 1, dstIdx, src1Idx, src2Idx); + } + else + { + dataOp(dstSuperTensor[dstIdx], src1SuperTensor[src1Idx], src2SuperTensor[src2Idx]); + } + dstIdx += dstStride; + src1Idx += src1Stride; + src2Idx += squashed ? 0 : src2Stride; } - - index += current_stride; - } + }; + operate_over_subtensor_impl(operate_over_subtensor_impl, 0, dstOffset, src1Offset, src2Offset); } template