Skip to content

Commit

Permalink
Update Java Tensor (#993)
Browse files Browse the repository at this point in the history
- Save reference to ByteBuffer that owns the data so it won't get garbage collected while the Tensor is around.
- Verify the passed in ByteBuffer has native byte order.
- Change default nativeHandle value to 0.
  • Loading branch information
edgchen1 authored Oct 21, 2024
1 parent dfbe14c commit 8e7f92c
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/java/src/main/java/ai/onnxruntime/genai/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
package ai.onnxruntime.genai;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public final class Tensor implements AutoCloseable {
private long nativeHandle = -1;
private long nativeHandle = 0;
private final ElementType elementType;
private final long[] shape;

// Buffer that owns the Tensor data.
private ByteBuffer dataBuffer = null;

// The values in this enum must match ONNX values
// https://github.com/onnx/onnx/blob/159fa47b7c4d40e6d9740fcf14c36fff1d11ccd8/onnx/onnx.proto#L499-L544
public enum ElementType {
Expand All @@ -33,7 +37,7 @@ public enum ElementType {
/**
* Constructs a Tensor with the given data, shape and element type.
*
* @param data The data for the Tensor. Must be a direct ByteBuffer.
* @param data The data for the Tensor. Must be a direct ByteBuffer with native byte order.
* @param shape The shape of the Tensor.
* @param elementType The type of elements in the Tensor.
* @throws GenAIException
Expand All @@ -51,8 +55,14 @@ public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws Gen
"Tensor data must be direct. Allocate with ByteBuffer.allocateDirect");
}

// for now, require native byte order as the bytes will be used directly.
if (data.order() != ByteOrder.nativeOrder()) {
throw new GenAIException("Tensor data must have native byte order.");
}

this.elementType = elementType;
this.shape = shape;
this.dataBuffer = data; // save a reference so the owning buffer will stay around.

nativeHandle = createTensor(data, shape, elementType.ordinal());
}
Expand Down

0 comments on commit 8e7f92c

Please sign in to comment.