Improves 2D tiled MatMulNBits by repeating A loads N times for each B load #23071

Open · wants to merge 6 commits into main
Changes from all commits
41 changes: 30 additions & 11 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -334,11 +334,14 @@ Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const
// Note in matmulnbits, B matrix is already transposed, however the following remains true
// for the shader below M describes A, N describes B and K is the hidden/shared dimension.
// K4/K8 are simply K divided by 4 or 8 respectively.
// A_REPEAT: the number of times each workgroup reloads A while sharing the same B tile.
shader.AdditionalImplementation() << R"INIT_SECTION(
// Matrix dimensions and quantization parameters
const TILE_SIZE : u32 = 16u;
const VALUES_PER_VEC4 : u32 = 4u;
const QUANTIZATION_BLOCK_SIZE : u32 = 32;
const A_REPEAT : u32 = 8u;

// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM,
// so we use BLOCKS_PER_CYCLE as 2u, or process weights 2 blocks at a time.
// This uses all 16 lanes on 12th gen intel chips.
@@ -349,13 +352,10 @@ const VECTORIZED_QUANTIZATION_BLOCK_SIZE: u32 = 8u; // QUANTIZATION_BLOCK_SIZE /
//Shared memory
var<workgroup> tile_A : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
var<workgroup> tile_B : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE>;
var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE * A_REPEAT>;

fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32)
{
if (a_global >= uniforms.M) {
return;
}
let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id];
tile_A[slot][parallel_id] = local_A;
}
@@ -417,21 +417,36 @@ fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
// a single wave in this approach of indexing.
let idx = u32(local_idx / TILE_SIZE);
let idy = u32(local_idx % TILE_SIZE);
let a_global_base = workgroup_id.x * TILE_SIZE;
let a_global_base = workgroup_id.x * TILE_SIZE * A_REPEAT;
let b_global_base = workgroup_id.y * TILE_SIZE;
let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE));
for (var vec_step:u32 = 0; vec_step < step_count; vec_step++)
{
workgroupBarrier();
loadA(idx, a_global_base+idx, vec_step, idy);
loadB(idx, b_global_base+idx, vec_step, idy);
workgroupBarrier();
let result = computeDotProduct(idx, idy);
tile_O[idx][idy]+=result;
for (var repeat_offset:u32=0; repeat_offset<A_REPEAT*TILE_SIZE; repeat_offset+=TILE_SIZE)
{
let a_global = a_global_base+idx+repeat_offset;
if (a_global < uniforms.M)
{
loadA(idx, a_global, vec_step, idy);
let result = computeDotProduct(idx, idy);
tile_O[idx+repeat_offset][idy]+=result;
}
}
}
workgroupBarrier();
if (a_global_base+idx < uniforms.M && b_global_base+idy < uniforms.N) {
output[(a_global_base+idx) * uniforms.N + b_global_base + idy] = tile_O[idx][idy];
for (var a_repeat:u32=0; a_repeat<A_REPEAT; a_repeat++)
{
let ridx = a_repeat * TILE_SIZE + idx;
let a_global = a_global_base+ridx;
if (a_global < uniforms.M)
{
output[(a_global) * uniforms.N + b_global_base + idy] = tile_O[ridx][idy];
}
}
}
)MAIN_FN";
return Status::OK();
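
A rough tile-traffic count implied by the A_REPEAT loop above, as a minimal sketch: it assumes TILE_SIZE = 16, A_REPEAT = 8, and an example prefill shape of M = 512, N = 4096 (illustrative values, not taken from the PR). Each workgroup still loads one B tile per K-step but now reuses it against A_REPEAT tiles of A, so total B tile fetches drop by a factor of A_REPEAT while total A tile fetches stay the same.

// Rough tile-traffic arithmetic implied by the A_REPEAT tiling above.
// Assumptions (not from the PR): TILE_SIZE = 16, A_REPEAT = 8, and an example
// prefill shape M = 512, N = 4096 chosen only to make the numbers concrete.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint64_t kTileSize = 16;
  constexpr uint64_t kARepeat = 8;
  constexpr uint64_t M = 512;   // rows of A (sequence length), assumed
  constexpr uint64_t N = 4096;  // rows of the transposed B matrix, assumed

  // Workgroups along the M and N dimensions before/after the change.
  const uint64_t wg_m_old = M / kTileSize;                 // 32
  const uint64_t wg_m_new = M / (kTileSize * kARepeat);    // 4
  const uint64_t wg_n = N / kTileSize;                     // 256

  // Per K-step every workgroup loads exactly one B tile, so total B tile
  // fetches equal the workgroup count; fewer workgroups along M means the
  // quantized weights are read A_REPEAT times less often.
  const uint64_t b_fetch_old = wg_m_old * wg_n;            // 8192
  const uint64_t b_fetch_new = wg_m_new * wg_n;            // 1024

  // Each workgroup now loads kARepeat A tiles per step, but there are
  // kARepeat times fewer workgroups along M, so A traffic is unchanged.
  const uint64_t a_fetch_old = wg_m_old * wg_n;            // 8192
  const uint64_t a_fetch_new = wg_m_new * kARepeat * wg_n; // 8192

  std::printf("B tile fetches per K-step: %llu -> %llu\n",
              (unsigned long long)b_fetch_old, (unsigned long long)b_fetch_new);
  std::printf("A tile fetches per K-step: %llu -> %llu\n",
              (unsigned long long)a_fetch_old, (unsigned long long)a_fetch_new);
  return 0;
}
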
@@ -486,8 +501,13 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
// MatMulNBitsProgramPrefill does not use any of the subgroup wgsl instructions. The subgroup
// size just helps with optimal lane usage in the shader.
constexpr int32_t subgroup_size = 16;
// How many times each workgroup reloads A while sharing B. This is tunable:
// 8 gives good performance for sequence lengths of 256/512, while 16 gives
// slightly better performance for sequence lengths of 1024.
// Note: this must match A_REPEAT in the shader.
constexpr unsigned int kMatMulPrefillARepeat = 8;
program.SetWorkgroupSize(tile_size * subgroup_size);
program.SetDispatchGroupSize((M + tile_size - 1) / tile_size,
program.SetDispatchGroupSize((M + (tile_size * kMatMulPrefillARepeat) - 1) / (tile_size * kMatMulPrefillARepeat),
(N + tile_size - 1) / tile_size,
1);
program
@@ -506,7 +526,6 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
// const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
constexpr uint32_t output_number = 1;
MatMulNBitsProgram program{output_number, gsl::narrow<int>(components_b), has_zero_points, use_block32};

if (use_block32) {
components = 1;
constexpr uint32_t workgroup_size = 128;
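
For reference, a minimal standalone sketch of the new dispatch arithmetic in MatMulNBits::ComputeInternal, assuming tile_size = 16 as in the shader's TILE_SIZE; the helper program and its names are illustrative, not onnxruntime code. With kMatMulPrefillARepeat = 8 each workgroup covers 16 * 8 = 128 rows of A, so only the X dimension of the dispatch shrinks; the workgroup size (tile_size * subgroup_size) is unchanged.

// Standalone sketch of the dispatch-size change, assuming tile_size = 16 and
// kMatMulPrefillARepeat = 8 as in the diff above. Not onnxruntime code.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t tile_size = 16;
  constexpr uint32_t a_repeat = 8;  // kMatMulPrefillARepeat / A_REPEAT in the shader
  const uint32_t sequence_lengths[] = {256, 512, 1024};

  for (uint32_t M : sequence_lengths) {
    // Before: one workgroup per tile_size (16) rows of A.
    const uint32_t old_x = (M + tile_size - 1) / tile_size;
    // After: one workgroup per tile_size * a_repeat (128) rows of A.
    const uint32_t new_x = (M + tile_size * a_repeat - 1) / (tile_size * a_repeat);
    std::printf("M=%4u: dispatch X %3u -> %2u workgroups\n", M, old_x, new_x);
  }
  return 0;
}

For M = 512 this gives 32 -> 4 workgroups along X, matching the tuning note above that A_REPEAT = 8 targets sequence lengths around 256/512.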