Skip to content

Commit

Permalink
Improve code reusing
Browse files Browse the repository at this point in the history
  • Loading branch information
jzm-intel committed Oct 24, 2024
1 parent 6d225e0 commit 74c4d22
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 64 deletions.
37 changes: 3 additions & 34 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import { ShapeUtil } from '../../../util';
import { ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../../types';
import {
createTensorShapeVariables,
getElementAt,
IndicesHelper,
inputVariable,
internalVariable,
Expand All @@ -40,6 +39,7 @@ import {
getActivationSnippet,
InternalActivationAttributes,
} from '../fuse-utils';
import { convertOutputBatchIndicesToInput } from '../matmul';

import { typeSnippet } from './activation_util';

Expand Down Expand Up @@ -378,37 +378,6 @@ const matMulReadWriteFnSource = (
const [batchVariable, aVariable, bVariable, outputVariable] = variables;
const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor);

// Helper that convert output batch to input batch indices using only the rank and the shape information in uniform
const convertOutputBatchIndicesToInput = (
inputVariable: IndicesHelper,
targetIndicesName: string,
lastTwoIndices: [number | string, number | string],
) => {
const inputRank = inputVariable.rank;
const outputBatchRank = batchVariable.rank;
if (inputRank === 2) {
return `var ${targetIndicesName} = ${inputVariable.type.indices}(${lastTwoIndices[0]}, ${lastTwoIndices[1]});`;
}
const inputBatchRank = inputRank - 2;
// Assume outputBatchRank >= inputBatchRank, the first outputBatchRank - inputBatchRank of outputBatchRank
// should be ignored.
const extendingInputRank = outputBatchRank - inputBatchRank;
return `
var ${targetIndicesName}: ${inputVariable.type.indices};
${Array.from({ length: inputBatchRank })
.map(
(_, i) => `
if (${getElementAt(inputVariable.shape, i, inputRank)} != 1) {
${inputVariable.indicesSet(targetIndicesName, i, getElementAt('batchIndices', i + extendingInputRank, outputBatchRank))}
} else {
${inputVariable.indicesSet(targetIndicesName, i, 0)}
}`,
)
.join('')}
${inputVariable.indicesSet(targetIndicesName, inputRank - 2, lastTwoIndices[0])}
${inputVariable.indicesSet(targetIndicesName, inputRank - 1, lastTwoIndices[1])}
`;
};
const source = `
fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${typeSnippet(
component,
Expand All @@ -418,7 +387,7 @@ const matMulReadWriteFnSource = (
let col = colIn * ${component};
if(row < uniforms.dim_a_outer && col < uniforms.dim_inner)
{
${convertOutputBatchIndicesToInput(aVariable, 'aIndices', ['u32(row)', 'u32(colIn)'])}
${convertOutputBatchIndicesToInput(aVariable, 'aIndices', batchVariable.rank, 'batchIndices', ['u32(row)', 'u32(colIn)'])}
value = ${aVariable.getByIndices('aIndices')};
}
return value;
Expand All @@ -432,7 +401,7 @@ const matMulReadWriteFnSource = (
let col = colIn * ${component};
if(row < uniforms.dim_inner && col < uniforms.dim_b_outer)
{
${convertOutputBatchIndicesToInput(bVariable, 'bIndices', ['u32(row)', 'u32(colIn)'])}
${convertOutputBatchIndicesToInput(bVariable, 'bIndices', batchVariable.rank, 'batchIndices', ['u32(row)', 'u32(colIn)'])}
value = ${bVariable.getByIndices('bIndices')};
}
return value;
Expand Down
66 changes: 36 additions & 30 deletions js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,40 @@ import {
InternalActivationAttributes,
} from './fuse-utils';

// Helper that convert output batch indices to input batch indices using only the rank and
// the shape information in uniform
export const convertOutputBatchIndicesToInput = (
inputVariable: IndicesHelper,
targetIndicesName: string,
outputBatchRank: number,
batchIndicesName: string,
lastTwoInputIndices: [number | string, number | string] = [0, 0],
) => {
const inputRank = inputVariable.rank;
if (inputRank === 2) {
return `var ${targetIndicesName} = ${inputVariable.type.indices}(${lastTwoInputIndices[0]}, ${lastTwoInputIndices[1]});`;
}
const inputBatchRank = inputRank - 2;
// Assume outputBatchRank >= inputBatchRank, the first outputBatchRank - inputBatchRank of
// outputBatchRank should be ignored.
const extendingInputRank = outputBatchRank - inputBatchRank;
return `
var ${targetIndicesName}: ${inputVariable.type.indices};
${Array.from({ length: inputBatchRank })
.map(
(_, i) => `
if (${getElementAt(inputVariable.shape, i, inputRank)} != 1) {
${inputVariable.indicesSet(targetIndicesName, i, getElementAt(batchIndicesName, i + extendingInputRank, outputBatchRank))}
} else {
${inputVariable.indicesSet(targetIndicesName, i, 0)}
}`,
)
.join('')}
${inputVariable.indicesSet(targetIndicesName, inputRank - 2, lastTwoInputIndices[0])}
${inputVariable.indicesSet(targetIndicesName, inputRank - 1, lastTwoInputIndices[1])}
`;
};

export const createNaiveMatmulProgramInfo = (
inputs: readonly TensorView[],
activationAttributes: InternalActivationAttributes,
Expand Down Expand Up @@ -87,34 +121,6 @@ export const createNaiveMatmulProgramInfo = (
];
appendActivationUniforms(activationAttributes, uniforms);

// Helper that convert output batch to input batch indices using only the rank and the shape information in uniform
const convertOutputBatchIndicesToInput = (inputVariable: IndicesHelper, targetIndicesName: string) => {
const inputRank = inputVariable.rank;
const outputBatchRank = batchDims.rank;
if (inputRank === 2) {
return `var ${targetIndicesName} = ${inputVariable.type.indices}(0u, 0u);`;
}
const inputBatchRank = inputRank - 2;
// Assume outputBatchRank >= inputBatchRank, the first outputBatchRank - inputBatchRank of outputBatchRank
// should be ignored.
const extendingInputRank = outputBatchRank - inputBatchRank;
return `
var ${targetIndicesName}: ${inputVariable.type.indices};
${Array.from({ length: inputBatchRank })
.map(
(_, i) => `
if (${getElementAt(inputVariable.shape, i, inputRank)} != 1) {
${inputVariable.indicesSet(targetIndicesName, i, getElementAt('batch_indices', i + extendingInputRank, outputBatchRank))}
} else {
${inputVariable.indicesSet(targetIndicesName, i, 0)}
}`,
)
.join('')}
${inputVariable.indicesSet(targetIndicesName, inputRank - 2, 0)}
${inputVariable.indicesSet(targetIndicesName, inputRank - 1, 0)}
`;
};

const calcResult = (): string => {
let calcStr = `var a_data: ${a.type.value};`;
for (let i = 0; i < aComponents; i++) {
Expand Down Expand Up @@ -146,9 +152,9 @@ export const createNaiveMatmulProgramInfo = (
let batch = index1 / stride1;
${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`}
${convertOutputBatchIndicesToInput(a, 'a_indices')}
${convertOutputBatchIndicesToInput(a, 'a_indices', batchDims.rank, 'batch_indices', [0, 0])}
let a_offset = ${a.indicesToOffset('a_indices')};
${convertOutputBatchIndicesToInput(b, 'b_indices')}
${convertOutputBatchIndicesToInput(b, 'b_indices', batchDims.rank, 'batch_indices', [0, 0])}
let b_offset = ${b.indicesToOffset('b_indices')};
var values: array<${output.type.value}, ${outputNumber}>;
for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) {
Expand Down

0 comments on commit 74c4d22

Please sign in to comment.