-
Notifications
You must be signed in to change notification settings - Fork 9
/
xfc_gemm_kernel.cu
247 lines (218 loc) · 12.3 KB
/
xfc_gemm_kernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
#include <iostream>
#include <vector>

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/device_memory.h"
// CUTLASS_CHECK(status): evaluates a cutlass::Status expression exactly once and,
// if it is not kSuccess, throws via TORCH_CHECK with the CUTLASS status string and
// the source location. Throwing (instead of calling exit()) lets the Python caller
// handle the failure rather than having the interpreter terminated.
// The do { ... } while (0) wrapper makes the macro a single statement, so it is
// safe inside unbraced if/else bodies.
#define CUTLASS_CHECK(status)                                          \
  do {                                                                 \
    const cutlass::Status cutlass_check_status_ = (status);            \
    TORCH_CHECK(cutlass_check_status_ == cutlass::Status::kSuccess,    \
                "Got cutlass error: ",                                 \
                cutlassGetStatusString(cutlass_check_status_),         \
                " at: ", __FILE__, ":", __LINE__);                     \
  } while (0)
// Runs a cutlass::gemm::device::Gemm instantiation (split_k_slices == 1, so no
// workspace) on `stream`. It first asks CUTLASS whether the problem is supported
// (such as operand alignment), so unsupported inputs raise an error instead of
// faulting inside the kernel.
template <typename Gemm>
static void run_cutlass_gemm(typename Gemm::Arguments const &args, cudaStream_t stream) {
  CUTLASS_CHECK(Gemm::can_implement(args));
  Gemm gemm_op;
  CUTLASS_CHECK(gemm_op(args, nullptr, stream));
}

// xfc_gemm_cuda: fp16 tensor-core GEMM via CUTLASS, written in place into mat_out.
//
//   apply_sigmoid == false: mat_out = alpha_in * (mat_in1 @ mat_in2) + beta_in * mat_out
//   apply_sigmoid == true : mat_out = sigmoid(alpha_in * (mat_in1 @ mat_in2) + beta_in * mat_out)
//
// mat_in1 is logically (M, K), mat_in2 is (K, N) and mat_out is (M, N). mat_out
// is both the C (source) and D (destination) operand.
//
// Requirements (checked; violations throw via TORCH_CHECK):
//   - All tensors are 2-D CUDA tensors on the same device.
//   - mat_in1 and mat_in2 are float16. mat_out is float16 when apply_sigmoid is
//     true and float32 otherwise.
//   - mat_out is row-major (stride(1) == 1, or N == 1).
//   - The operand layouts, inferred from strides, are one of:
//       apply_sigmoid == true : mat_in1 row-major,    mat_in2 column-major
//       apply_sigmoid == false: mat_in1 column-major, mat_in2 row-major
//                                 (weight-gradient case), or
//                               mat_in1 row-major,    mat_in2 row-major
//                                 (input-gradient case; parallel split-K, 128 slices)
//   - CUTLASS accepts the problem (Gemm::can_implement).
//
// Kernels use SM70 (Volta) tensor-core MMA with an 8x8x4 instruction shape. They
// are launched on PyTorch's current CUDA stream for mat_out's device.
void xfc_gemm_cuda(
    torch::Tensor mat_in1,
    torch::Tensor mat_in2,
    torch::Tensor mat_out,
    float alpha_in,
    float beta_in,
    bool apply_sigmoid) {
  // Element types of the operands, the accumulator and the epilogue math.
  using ElementAccumulator = float;                   // accumulator type
  using ElementComputeEpilogue = ElementAccumulator;  // type used for alpha/beta math
  using ElementInputA = cutlass::half_t;              // element type of A (mat_in1)
  using ElementInputB = cutlass::half_t;              // element type of B (mat_in2)
  using MMAOp = cutlass::arch::OpClassTensorOp;       // tensor-core MMA instructions
  using SmArch = cutlass::arch::Sm70;                 // Volta (SM70) target
  using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>;  // MMA instruction tile M=8, N=8, K=4
  // Threadblock tile and warp tile (M, N, K), shared by all instantiations below.
  using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 64, 32>;
  using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 32>;
  // Mapping from CUDA block index to output tile. The template argument is
  // CUTLASS's swizzle factor.
  using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;
  // Number of mainloop pipeline stages (non-split-K kernels only).
  constexpr int NumStages = 2;

  // ---- Validate arguments ------------------------------------------------
  // The raw data_ptr() casts below reinterpret memory. A wrong dtype, device or
  // shape would otherwise yield garbage results or out-of-bounds accesses.
  TORCH_CHECK(mat_in1.is_cuda() && mat_in2.is_cuda() && mat_out.is_cuda(),
              "xfc_gemm_cuda: all tensors must be CUDA tensors");
  TORCH_CHECK(mat_in1.device() == mat_out.device() &&
                  mat_in2.device() == mat_out.device(),
              "xfc_gemm_cuda: all tensors must be on the same device");
  TORCH_CHECK(mat_in1.dim() == 2 && mat_in2.dim() == 2 && mat_out.dim() == 2,
              "xfc_gemm_cuda: all tensors must be 2-D");
  TORCH_CHECK(mat_in1.scalar_type() == torch::kHalf &&
                  mat_in2.scalar_type() == torch::kHalf,
              "xfc_gemm_cuda: mat_in1 and mat_in2 must be float16");
  TORCH_CHECK(mat_out.scalar_type() == (apply_sigmoid ? torch::kHalf : torch::kFloat),
              "xfc_gemm_cuda: mat_out must be float16 when apply_sigmoid is true "
              "and float32 otherwise");

  const int M = mat_out.size(0);
  const int N = mat_out.size(1);
  const int K = mat_in2.size(0);
  TORCH_CHECK(mat_in1.size(0) == M && mat_in1.size(1) == K && mat_in2.size(1) == N,
              "xfc_gemm_cuda: expected mat_in1 (M, K), mat_in2 (K, N), mat_out (M, N); got ",
              mat_in1.sizes(), ", ", mat_in2.sizes(), ", ", mat_out.sizes());
  // Every instantiation writes D row-major with leading dimension stride(0).
  TORCH_CHECK(mat_out.stride(1) == 1 || N == 1,
              "xfc_gemm_cuda: mat_out must be row-major (stride(1) == 1)");

  if (M == 0 || N == 0) {
    return;  // Empty output: nothing to compute, and a zero-sized grid cannot launch.
  }

  // Launch on mat_out's device and on PyTorch's current stream, so the GEMM is
  // ordered with the surrounding PyTorch work.
  const c10::cuda::CUDAGuard device_guard(mat_out.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const ElementComputeEpilogue alpha = ElementComputeEpilogue(alpha_in);
  const ElementComputeEpilogue beta = ElementComputeEpilogue(beta_in);

  // Operand layouts inferred from strides:
  //   row-major    = unit stride along columns, leading dimension stride(0)
  //   column-major = unit stride along rows,    leading dimension stride(1)
  const bool in1_row_major = mat_in1.stride(1) == 1;
  const bool in1_col_major = mat_in1.stride(0) == 1;
  const bool in2_row_major = mat_in2.stride(1) == 1;
  const bool in2_col_major = mat_in2.stride(0) == 1;

  const auto *ptr_a = static_cast<const cutlass::half_t *>(mat_in1.data_ptr());
  const auto *ptr_b = static_cast<const cutlass::half_t *>(mat_in2.data_ptr());

  if (apply_sigmoid) {
    TORCH_CHECK(in1_row_major && in2_col_major,
                "Only rowXcolumn supported for applysigmoid=True. TBD");
    // fp16 output with a sigmoid epilogue.
    using ElementOutputH = cutlass::half_t;
    using EpilogueOpSH = cutlass::epilogue::thread::LinearCombinationSigmoid<
        ElementOutputH, 128 / cutlass::sizeof_bits<ElementOutputH>::value,
        ElementAccumulator, ElementComputeEpilogue>;
    using Gemm = cutlass::gemm::device::Gemm<
        ElementInputA, cutlass::layout::RowMajor,
        ElementInputB, cutlass::layout::ColumnMajor,
        ElementOutputH, cutlass::layout::RowMajor,
        ElementAccumulator, MMAOp, SmArch,
        ShapeMMAThreadBlock, ShapeMMAWarp, ShapeMMAOp,
        EpilogueOpSH, SwizzleThreadBlock, NumStages>;
    auto *ptr_out = static_cast<cutlass::half_t *>(mat_out.data_ptr());
    Gemm::Arguments args({M, N, K},                     // problem size
                         {ptr_a, mat_in1.stride(0)},    // A (mat_in1), row-major
                         {ptr_b, mat_in2.stride(1)},    // B (mat_in2), column-major
                         {ptr_out, mat_out.stride(0)},  // C (source), row-major
                         {ptr_out, mat_out.stride(0)},  // D (destination), same buffer as C
                         {alpha, beta},                 // epilogue scalars
                         1);                            // split_k_slices
    run_cutlass_gemm<Gemm>(args, stream);
    return;
  }

  // fp32 output with a plain linear-combination epilogue.
  using ElementOutputF = float;
  using EpilogueOpDF = cutlass::epilogue::thread::LinearCombination<
      ElementOutputF, 128 / cutlass::sizeof_bits<ElementOutputF>::value,
      ElementAccumulator, ElementComputeEpilogue>;
  auto *ptr_out = static_cast<float *>(mat_out.data_ptr());

  if (in1_col_major && in2_row_major) {
    // Gradient with respect to weights: A column-major, B row-major.
    using Gemm = cutlass::gemm::device::Gemm<
        ElementInputA, cutlass::layout::ColumnMajor,
        ElementInputB, cutlass::layout::RowMajor,
        ElementOutputF, cutlass::layout::RowMajor,
        ElementAccumulator, MMAOp, SmArch,
        ShapeMMAThreadBlock, ShapeMMAWarp, ShapeMMAOp,
        EpilogueOpDF, SwizzleThreadBlock, NumStages>;
    Gemm::Arguments args({M, N, K},                     // problem size
                         {ptr_a, mat_in1.stride(1)},    // A (mat_in1), column-major
                         {ptr_b, mat_in2.stride(0)},    // B (mat_in2), row-major
                         {ptr_out, mat_out.stride(0)},  // C (source), row-major
                         {ptr_out, mat_out.stride(0)},  // D (destination), same buffer as C
                         {alpha, beta},                 // epilogue scalars
                         1);                            // split_k_slices
    run_cutlass_gemm<Gemm>(args, stream);
  } else if (in1_row_major && in2_row_major) {
    // Gradient with respect to inputs: K is large, so use parallel split-K.
    // K is divided into split_k_slices partial GEMMs whose results are then
    // reduced into D. (A plain device::Gemm with 64x128x32 tiles was used
    // previously for this case.)
    constexpr int split_k_slices = 128;
    using Gemm = cutlass::gemm::device::GemmSplitKParallel<
        ElementInputA, cutlass::layout::RowMajor,
        ElementInputB, cutlass::layout::RowMajor,
        ElementOutputF, cutlass::layout::RowMajor,
        ElementAccumulator, MMAOp, SmArch,
        ShapeMMAThreadBlock, ShapeMMAWarp, ShapeMMAOp,
        EpilogueOpDF>;
    Gemm::Arguments args({M, N, K},                     // problem size
                         {ptr_a, mat_in1.stride(0)},    // A (mat_in1), row-major
                         {ptr_b, mat_in2.stride(0)},    // B (mat_in2), row-major
                         {ptr_out, mat_out.stride(0)},  // C (source), row-major
                         {ptr_out, mat_out.stride(0)},  // D (destination), same buffer as C
                         {alpha, beta},                 // epilogue scalars
                         split_k_slices);
    // Workspace for the per-slice partial results. It comes from PyTorch's
    // caching allocator, so it is stream-ordered with the launch below and
    // avoids a cudaMalloc/cudaFree pair on every call.
    const size_t workspace_size = Gemm::get_workspace_size(args);
    torch::Tensor workspace = torch::empty(
        {static_cast<int64_t>(workspace_size)}, mat_out.options().dtype(torch::kUInt8));
    CUTLASS_CHECK(Gemm::can_implement(args));
    Gemm gemm_op;
    CUTLASS_CHECK(gemm_op.initialize(args, workspace.data_ptr()));
    CUTLASS_CHECK(gemm_op(stream));
  } else {
    TORCH_CHECK(false, "Only columnxrow and rowxrow supported for applysigmoid=False. TBD");
  }
}