Skip to content

Commit

Permalink
[WebNN EP] Use boolean flags instead of MLTensorUsage (#22497)
Browse files Browse the repository at this point in the history
Fixed #22495

We will keep MLTensorUsage until it is removed from Chromium.

---------

Co-authored-by: Dwayne Robinson <[email protected]>
  • Loading branch information
Honry and fdwr authored Oct 23, 2024
1 parent 63a07c1 commit e6e94e6
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
6 changes: 5 additions & 1 deletion js/web/lib/wasm/jsep/webnn/tensor-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class TensorIdTracker {

// eslint-disable-next-line no-bitwise
const usage = MLTensorUsage.READ | MLTensorUsage.WRITE;
this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage);
this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage, true, true);

if (copyOld && this.activeUpload) {
this.wrapper.write(this.activeUpload);
Expand Down Expand Up @@ -306,6 +306,8 @@ class TensorManagerImpl implements TensorManager {
dataType: MLOperandDataType,
shape: readonly number[],
usage: MLTensorUsageFlags,
writable: boolean,
readable: boolean,
): Promise<TensorWrapper> {
const sessionId = this.backend.currentSessionId;
for (const [index, tensor] of this.freeTensors.entries()) {
Expand All @@ -322,6 +324,8 @@ class TensorManagerImpl implements TensorManager {
shape,
dimensions: shape,
usage,
writable,
readable,
});
return new TensorWrapper({ sessionId, context, tensor, dataType, shape });
}
Expand Down
4 changes: 4 additions & 0 deletions js/web/lib/wasm/jsep/webnn/webnn.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ type MLNamedTensor = Record<string, MLTensor>;

type MLTensorUsageFlags = number;

// TODO(@Honry): Remove this once it is deprecated in Chromium.
declare const MLTensorUsage: {
readonly WEBGPU_INTEROP: MLTensorUsageFlags;
readonly READ: MLTensorUsageFlags;
Expand All @@ -400,6 +401,9 @@ declare const MLTensorUsage: {

interface MLTensorDescriptor extends MLOperandDescriptor {
usage: MLTensorUsageFlags;
importableToWebGPU?: boolean;
readable?: boolean;
writable?: boolean;
}

interface MLContext {
Expand Down
2 changes: 2 additions & 0 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty
// Assign both shape and dimensions while transitioning to new API.
dimensions: dims as number[],
usage: MLTensorUsage.READ,
readable: true,
});

return ort.Tensor.fromMLTensor(mlTensor, {
Expand All @@ -686,6 +687,7 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso
// Assign both shape and dimensions while transitioning to new API.
dimensions: cpuTensor.dims as number[],
usage: MLTensorUsage.WRITE,
writable: true,
});
mlContext.writeTensor(mlTensor, cpuTensor.data);
return ort.Tensor.fromMLTensor(mlTensor, {
Expand Down

0 comments on commit e6e94e6

Please sign in to comment.