Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom ops Rotary #738

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions operators/cuda/cuda_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "cuda/mul_sigmoid.h"
#include "cuda/negxplus1.h"
#include "cuda/replace_zero.h"
#include "cuda/rotary.h"
#include "cuda/scatter_nd_of_shape.h"
#include "cuda/transpose_cast.h"
#endif
Expand Down Expand Up @@ -36,6 +37,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<float>),
CustomCudaStructV2("Rotary", contrib::Rotary<float>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
#if ORT_API_VERSION >= 16

Expand All @@ -48,6 +50,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<ortc::MFloat16>),
CustomCudaStructV2("Rotary", contrib::Rotary<ortc::MFloat16>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
Expand Down
83 changes: 83 additions & 0 deletions operators/cuda/rotary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "rotary_impl.cuh"
#include "ortx_common.h"

namespace contrib {

/**
* Y = Rotary(X) is equivalent to if side == LEFT:
*
* N = X.shape[-1]
* Y = X.copy()
* Y[...,:N/2] = X[...,N/2:]
* Y[...,N/2:] = -X[...,:N/2]
*
* And the opposite if side == RIGHT.
xadupre marked this conversation as resolved.
Show resolved Hide resolved
*/
template <typename T>
struct Rotary {
template <typename TDict>
OrtxStatus OnModelAttach(const TDict& dict) {
std::string empty;
std::string side = dict.TryToGetAttributeWithDefault("side", empty);
xadupre marked this conversation as resolved.
Show resolved Hide resolved
if (side == "left") {
side_ = RotarySide::LEFT;
}
else if (side == "right") {
side_ = RotarySide::RIGHT;
}
else {
return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."};
}

return {};
}
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
xadupre marked this conversation as resolved.
Show resolved Hide resolved
const ortc::Tensor<T>& input,
const ortc::Tensor<int64_t>& split,
wschin marked this conversation as resolved.
Show resolved Hide resolved
ortc::Tensor<T>& output) const {
const T* input_data = input.Data();
auto input_shape = input.Shape();
T* output_data = output.Allocate(input_shape);
auto input_length = input.NumberOfElement();
if (0 == input_length) {
return {};
}

auto shape_split = split.Shape();
if (shape_split.size() != 1 || shape_split[0] != 2) {
wschin marked this conversation as resolved.
Show resolved Hide resolved
return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."};
}
const int64_t* split_data = split.Data();
if (split_data[0] != split_data[1]) {
return {kOrtxErrorInvalidArgument, "Only equal split are allowed."};
xadupre marked this conversation as resolved.
Show resolved Hide resolved
}
if (split_data[0] * 2 != input_shape[input_shape.size()-1]) {
return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."};
xadupre marked this conversation as resolved.
Show resolved Hide resolved
}

LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
static_cast<int>(input_shape[input_shape.size()-1]),
xadupre marked this conversation as resolved.
Show resolved Hide resolved
input_data,
split_data,
output_data,
side_);
return {};
}

static OrtMemType GetInputMemoryType(size_t input_index) {
if (input_index == 1) // split
return OrtMemType::OrtMemTypeCPUInput;
return OrtMemType::OrtMemTypeDefault;
}

private:
RotarySide side_;
};

} // namespace contrib
81 changes: 81 additions & 0 deletions operators/cuda/rotary_impl.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "rotary_impl.cuh"
#include "cuda_type.h"

#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

using namespace Ort::Custom;

template <typename T> __device__ __inline__ T _neg(const T x) { return -x; }

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half _neg(const half x) {
return __float2half(-__half2float(x));
}
#endif

template <typename T, RotarySide side>
xadupre marked this conversation as resolved.
Show resolved Hide resolved
__global__ void RotaryKernel(T *output_data, const T *input_data, CUDA_LONG half_N, CUDA_LONG half_stride) {
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= half_N)
return;
CUDA_LONG last = id % half_stride;
id = (id - last) * 2 + last;
if (side == RotarySide::RIGHT) {
output_data[id + half_stride] = input_data[id];
output_data[id] = _neg(input_data[id + half_stride]);
} else {
output_data[id + half_stride] = _neg(input_data[id]);
output_data[id] = input_data[id + half_stride];
}
}

template <typename T>
cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
const T* input_data, const int64_t* /* split_data */, T* output_data, RotarySide side) {
if (input_length == 0)
return cudaGetLastError();
using TT = typename contrib::CudaT<T>::MappedType;

CUDA_LONG N = static_cast<CUDA_LONG>(input_length);
CUDA_LONG stride = static_cast<CUDA_LONG>(last_dim);

const int num_threads_per_block = 256;
const int num_elements_per_thread =
(N / 2 + num_threads_per_block - 1) / num_threads_per_block;

switch (side) {
case RotarySide::LEFT:
RotaryKernel<TT, RotarySide::LEFT>
<<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(reinterpret_cast<TT*>(output_data),
reinterpret_cast<const TT*>(input_data),
N / 2, stride / 2);
break;
case RotarySide::RIGHT:
RotaryKernel<TT, RotarySide::RIGHT>
<<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(reinterpret_cast<TT*>(output_data),
reinterpret_cast<const TT*>(input_data),
N / 2, stride / 2);
break;
}
return cudaGetLastError();
}

template <>
cudaError_t LaunchRotaryKernel<float>(cudaStream_t stream, int input_length, int last_dim,
const float* input_data, const int64_t* split_data, float* output_data, RotarySide side) {
return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
}

template <>
cudaError_t LaunchRotaryKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, int last_dim,
const ortc::MFloat16* input_data, const int64_t* split_data,
ortc::MFloat16* output_data, RotarySide side) {
return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
}
15 changes: 15 additions & 0 deletions operators/cuda/rotary_impl.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

enum class RotarySide : int {
LEFT = 1,
RIGHT = 2,
};

template <typename T>
cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
const T* input_data, const int64_t* split_data, T* output_data, RotarySide side);
61 changes: 61 additions & 0 deletions test/cuda/test_cudaops.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,67 @@ def test_masked_scatternd_of_shape_standalone_cuda_big(self):
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, True)
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, True)

def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)):
model2 = helper.make_model(
helper.make_graph(
[
helper.make_node(
"Rotary",
["X", "splits"],
["Y"],
domain="ai.onnx.contrib",
side=side,
)
],
"nd",
[
helper.make_tensor_value_info("X", itype, [None, None, None, None]),
helper.make_tensor_value_info("splits", TensorProto.INT64, [2]),
],
[helper.make_tensor_value_info("Y", itype, [None, None, None, None])],
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)

dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
x = (np.arange(np.prod(input_shape)) + 1).reshape(input_shape).astype(dtype)
splits = np.array([x.shape[-1] // 2, x.shape[-1] // 2], dtype=np.int64)

expected = x.copy()
half = x.shape[-1] // 2
if side == "left":
expected[:, :, :, :half] = x[:, :, :, half:]
expected[:, :, :, half:] = -x[:, :, :, :half]
else:
expected[:, :, :, :half] = -x[:, :, :, half:]
expected[:, :, :, half:] = x[:, :, :, :half]

feeds = dict(X=x, splits=splits)
opts = _ort.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds)[0]
assert_almost_equal(expected, got)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_rotary_cuda(self):
self._rotary_cuda(TensorProto.FLOAT, "left")
self._rotary_cuda(TensorProto.FLOAT, "right")
self._rotary_cuda(TensorProto.FLOAT16, "left")
self._rotary_cuda(TensorProto.FLOAT16, "right")

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_bigger_rotary_cuda(self):
sh = (2, 2, 1024, 8)
self._rotary_cuda(TensorProto.FLOAT, "left", input_shape=sh)
self._rotary_cuda(TensorProto.FLOAT, "right", input_shape=sh)
self._rotary_cuda(TensorProto.FLOAT16, "left", input_shape=sh)
self._rotary_cuda(TensorProto.FLOAT16, "right", input_shape=sh)

def _transpose_cast_cuda(self, itype):
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16
Expand Down
Loading