Commit 2dff8d5

Use output buffer instead of computing softmax in-place.
satyajandhyala committed Oct 25, 2024
1 parent 27bc573 commit 2dff8d5
Showing 1 changed file with 13 additions and 12 deletions.

js/web/lib/wasm/jsep/webgpu/ops/attention.ts (13 additions, 12 deletions)

@@ -288,7 +288,7 @@ const initVarStub = (
   }
 };
 
-const createInPlaceSoftmaxProgramInfo = (
+const createSoftmaxProgramInfo = (
   input: TensorView,
   batchSize: number,
   numHeads: number,
@@ -324,7 +324,8 @@ const createInPlaceSoftmaxProgramInfo = (
     inputDependencies.push('type');
   }
   const getShaderSource = (shaderHelper: ShaderHelper) => {
-    const inputHelper = outputVariable('x', input.dataType, input.dims, components);
+    const inputHelper = inputVariable('x', input.dataType, input.dims, components);
+    const outputHelper = outputVariable('y', input.dataType, input.dims, components);
     const inputHelpers = [inputHelper];
     const seqLensInputHelper = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
     if (seqLensInputHelper) {
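
Note (editorial, not part of the commit): this hunk is the core of the change. The score buffer `x` is now declared with `inputVariable`, making it read-only in the generated WGSL, while a separate writable `y` is declared with `outputVariable`. A minimal sketch of the pattern these helpers follow, using only calls visible in this diff (the workgroup size and shader body are placeholders):

    // Sketch under stated assumptions; not the actual generated shader.
    const getShaderSource = (shaderHelper: ShaderHelper) => {
      const x = inputVariable('x', input.dataType, input.dims, components); // read-only scores
      const y = outputVariable('y', input.dataType, input.dims, components); // writable result
      return `
      ${shaderHelper.registerUniforms(uniforms).declareVariables(x, y)}
      ${shaderHelper.mainStart([64, 1, 1])}
        // read attention scores from x, write normalized values to y
      }`;
    };
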
Expand All @@ -350,7 +351,7 @@ const createInPlaceSoftmaxProgramInfo = (
return `
var<workgroup> thread_max: array<f32, ${WG}>;
var<workgroup> thread_sum: array<f32, ${WG}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers)}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers, outputHelper)}
${shaderHelper.mainStart([WG, 1, 1])}
let batchIdx = workgroup_id.z / uniforms.num_heads;
let headIdx = workgroup_id.z % uniforms.num_heads;
@@ -408,19 +409,19 @@ const createInPlaceSoftmaxProgramInfo = (
       if (sum == 0) {
         for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
-          x[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length));
+          y[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length));
         }
       } else {
         for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
           var f32input = ${f32Type}(x[offset + i]);
-          x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum);
+          y[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum);
         }
       }
       ${
         seqLens
           ? `
       for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) {
-        x[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0));
+        y[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0));
       }`
           : ''
       };
@@ -432,7 +433,7 @@ const createInPlaceSoftmaxProgramInfo = (
     shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies },
     getShaderSource,
     getRunData: () => ({
-      outputs: [],
+      outputs: [{ dims: input.dims, dataType: input.dataType, gpuDataType: GpuDataType.default }],
       dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads },
       programUniforms,
     }),
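
Note (editorial): the program now advertises one output with the input's shape and type, so the runtime allocates a destination buffer for `y` instead of the shader mutating `x`. The dispatch geometry is unchanged; a worked example with hypothetical sizes, using the formula from this hunk:

    // Hypothetical sizes: batchSize = 1, numHeads = 12, sequenceLength = 128,
    // totalSequenceLength = 128, WG = 64.
    const WG = 64;
    const dispatchGroup = {
      x: Math.ceil(128 / WG), // 2: workgroups tiling one row of scores
      y: 128,                 // one row per query position
      z: 1 * 12,              // one slice per (batch, head) pair
    };
    // -> { x: 2, y: 128, z: 12 }; each workgroup reduces part of one softmax row.
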
@@ -844,8 +845,8 @@ export const applyAttention = (
   )[0];
 
   // Run Softmax
-  context.compute(
-    createInPlaceSoftmaxProgramInfo(
+  const softmaxOutput = context.compute(
+    createSoftmaxProgramInfo(
       probs,
       parameters.batchSize,
       parameters.numHeads,
@@ -855,11 +856,11 @@
       seqLens,
       totalSequenceLengthInput,
     ),
-    { inputs: seqLens && totalSequenceLengthInput ? [probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [] },
-  );
+    { inputs: seqLens && totalSequenceLengthInput ? [probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [-1] },
+  )[0];
 
   // Run AttentionScore
-  const inputsV = [probs, v];
+  const inputsV = [softmaxOutput, v];
   if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
     inputsV.push(pastValue);
   }
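
Note (editorial): at the call site, `context.compute` now returns the softmax result rather than mutating `probs` in place. `outputs: [-1]` appears to follow the JSEP convention of requesting a framework-managed temporary GPU tensor (an assumption about the convention; the commit itself does not spell it out), and `[0]` takes the first returned TensorView. The attention-score matmul then consumes `softmaxOutput` while `probs` stays untouched:

    // Sketch of the new dataflow; identifiers are taken from the diff.
    const softmaxOutput = context.compute(
      createSoftmaxProgramInfo(probs /* , ...parameters as in the diff */),
      { inputs: [probs], outputs: [-1] }, // -1: temporary output (assumed convention)
    )[0];
    const inputsV = [softmaxOutput, v]; // the V matmul reads the new buffer
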
