Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Optimize matmulnbits with M > 1 #23092

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 185 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,193 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
};
};

// Currently, only support blockSize = 32.
export const createMatMulNBitsWithLargeMProgramInfo = (
inputs: readonly TensorView[],
attributes: MatMulNBitsAttributes,
): ProgramInfo => {
const inputShape = inputs[0].dims;
const aRank = inputShape.length;
const dimAOuter = inputShape[aRank - 2];
const dimInner = attributes.k;
const dimBOuter = attributes.n;
const batchDims = inputShape.slice(0, aRank - 2);
const batchSize = ShapeUtil.size(batchDims);
const blobSize = inputs[1].dims[2];
const blobSizeInWords = blobSize / 4;
const dataType = inputs[0].dataType;
const aComponents = getMaxComponents(attributes.k);
const bComponents = getMaxComponents(blobSizeInWords);
const outputShape = batchDims.concat([dimAOuter, dimBOuter]);

const workgroupSize = 64;
const tileM = 4;
const workgroupY = 8;
const workgroupX = workgroupSize / workgroupY;
const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data.
const aLengthPerTile = tileSize / aComponents;
const blocksPerTile = tileSize / attributes.blockSize;

const programUniforms: ProgramUniform[] = [];
const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
bShape.splice(-1, 1, blobSizeInWords / bComponents);
programUniforms.push(...createTensorShapeVariables(inputShapeTemp));
programUniforms.push(...createTensorShapeVariables(bShape));
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
if (inputs.length === 4) {
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
}
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter];
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));

const getShaderSource = (shaderHelper: ShaderHelper) => {
const inputRank = inputShapeTemp.length;
const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents);
const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
const inputVariables = [a, b, scales];
const zeroPoints =
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
if (zeroPoints) {
inputVariables.push(zeroPoints);
}
const outputRank = outputShapeTemp.length;
const output = outputVariable('output', inputs[0].dataType, outputRank);
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const readA = () => {
switch (aComponents) {
case 1:
return `
let a_data0 = vec4<${dataType}>(sub_a[r][word_offset], sub_a[r][word_offset + 1], sub_a[r][word_offset + 2], sub_a[r][word_offset + 3]);
let a_data1 = vec4<${dataType}>(sub_a[r][word_offset + 4], sub_a[r][word_offset + 5], sub_a[r][word_offset + 6], sub_a[r][word_offset + 7]);`;
case 2:
return `
let a_data0 = vec4<${dataType}>(sub_a[r][word_offset], sub_a[r][word_offset + 1]);
let a_data1 = vec4<${dataType}>(sub_a[r][word_offset + 2], sub_a[r][word_offset + 3]);`;
case 4:
return `
let a_data0 = sub_a[r][word_offset];
let a_data1 = sub_a[r][word_offset + 1];`;
default:
throw new Error(`${aComponents}-component is not supported.`);
}
};

const loadTileA = () => {
let str = '';
for (let i = 0; i < tileM; i++) {
str += `sub_a[${i}][a_offset] = mm_readA(batch, row + ${i}, a_col);`;
}
return str;
};
return `
fn mm_readA(batch: u32, row : u32, col : u32) -> ${a.type.value} {
if (row < uniforms.a_shape[1] && col < uniforms.a_shape[2])
{
return ${a.getByIndices(`${a.type.indices}(batch, row, col)`)};
} else {
return ${a.type.value}(0);
}
}

var<workgroup> sub_a: array<array<${a.type.value}, ${aLengthPerTile}>, ${tileM}>;
var<workgroup> inter_results: array<array<array<${output.type.value}, ${workgroupX}>, ${workgroupY}>, ${tileM}>;
${shaderHelper.declareVariables(...inputVariables, output)}
${shaderHelper.mainStart([workgroupX, workgroupY, 1])}
let col = workgroup_id.x * ${workgroupY};
let row = workgroup_id.y * ${tileM};
let batch = workgroup_id.z;
let n_blocks_per_col = uniforms.b_shape[1];
let num_tiles = (n_blocks_per_col - 1) / ${blocksPerTile} + 1;

// Loop over shared dimension.
for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
let a_col_start = tile * ${aLengthPerTile};
// load one tile A data into shared memory.
for (var a_offset = local_idx; a_offset < ${aLengthPerTile}; a_offset += ${workgroupSize})
{
let a_col = a_col_start + a_offset;
${loadTileA()}
}
workgroupBarrier();

// each thread process one block
let b_col = col + local_id.y;
let block = tile * ${blocksPerTile} + local_id.x;
${
zeroPoints
? `
let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
let zero_point_byte_count = b_col * zero_point_bytes_per_col + (block >> 0x1u);
let zero_point_word_index = zero_point_byte_count >> 0x2u;
let zero_point_byte_offset = zero_point_byte_count & 0x3u;
let zero_point_nibble_offset: u32 = block & 0x1u;
let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
let zero_point = ${dataType}((zero_point_word) & 0xFu);`
: `
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${8.0});`
}
let scale = ${scales.getByOffset(`b_col * n_blocks_per_col + block`)};
let b_data = ${b.getByIndices(`${b.type.indices}(b_col, block, 0)`)};
var word_offset = local_id.x * ${attributes.blockSize / aComponents};
for (var i: u32 = 0; i < ${bComponents}; i++) {
let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`};
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
let b_quantized_values = mat2x4<${dataType}>(${Array.from(
{ length: 4 },
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
).join(', ')});
let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
for (var r = 0; r < ${tileM}; r++) {
${readA()}
inter_results[r][local_id.y][local_id.x] += ${Array.from(
{ length: 2 },
(_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
).join(' + ')};
}
word_offset += ${8 / aComponents};
}
workgroupBarrier();
}

if (local_id.y < ${tileM}) {
var output_value: ${output.type.value} = ${output.type.value}(0);
for (var b = 0u; b < ${workgroupX}; b++) {
output_value += inter_results[local_id.y][local_id.x][b];
}
if (row + local_id.y < uniforms.output_shape[1] && col + local_id.x < uniforms.output_shape[2])
{
${output.setByIndices(`${output.type.indices}(batch, row + local_id.y, col + local_id.x)`, 'output_value')}
}
}
}`;
};
return {
name: 'MatMulNBitsWithLargeM',
shaderCache: {
hint: `${attributes.blockSize};${aComponents};${bComponents};${workgroupX};${workgroupY};${tileM}`,
inputDependencies: Array(inputs.length).fill('rank'),
},
getRunData: () => ({
outputs: [{ dims: outputShape, dataType }],
dispatchGroup: { x: Math.ceil(dimBOuter / workgroupY), y: Math.ceil(dimAOuter / tileM), z: batchSize },
programUniforms,
}),
getShaderSource,
};
};

export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
validateInputs(context.inputs, attributes);
if (
const inputShape = context.inputs[0].dims;
const m = inputShape[inputShape.length - 2];
if (m > 1 && attributes.blockSize === 32) {
context.compute(createMatMulNBitsWithLargeMProgramInfo(context.inputs, attributes));
} else if (
attributes.blockSize === 32 &&
context.adapterInfo.isVendor('intel') &&
context.adapterInfo.isArchitecture('gen-12lp')
Expand Down
Loading