TensorFlow Java is still in an alpha stage and is therefore subject to breaking changes between releases. This guide explains in detail how to migrate your code from a previous version to a new one that includes changes that are not backward compatible.
In previous versions, the Tensor class was parameterized with its tensor type interface, which is part of the TType family. To access the tensor memory directly from the JVM, an explicit conversion between Tensor and its tensor type was required by calling tensor.data().
In 0.3.0, tensors are always typed, making this generic parameter and explicit mapping obsolete. As soon as you get a handle to a tensor, you can access its memory directly for reading (or writing, for most tensor types), and no conversion is required. Any instance of a class in the TType family can now also be manipulated directly as a Tensor (e.g. to be passed to a session for inference).
Steps:
- Replace a parameterized Tensor by its parameter (e.g. Tensor<TFloat32> -> TFloat32)
- Replace instances of Tensor<?> with an unknown parameter by Tensor
- Remove any invocation of Tensor.data() (e.g. tensor.data().getFloat() -> tensor.getFloat())
- Replace any invocation of Operand.data() by Operand.asTensor()
In previous versions, the DataType class was used to carry information about the type of a Tensor, which could then be converted back to a tensor of that type (see previous section). Since there was an exact parity between the interfaces of the TType family and the instances of DataType, the latter has been dropped in 0.3.0 in favor of the standard Java type system, for a more idiomatic experience.
Steps:
- Replace all accesses to the DTYPE field of a TType interface by its class (e.g. TFloat32.DTYPE -> TFloat32.class)
- Use Java type system for checking tensor types at runtime (e.g. using instanceof or isAssignableFrom)
- Replace any invocation to Tensor.expect() by an explicit cast (e.g. tensor.expect(TFloat32.DTYPE) -> (TFloat32)tensor)
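The runtime checks mentioned in these steps can be sketched in plain Java. The example below is only an illustration: the nested interfaces Tensor, TType, TNumber, TFloat32, TInt32 and TString are simplified, hypothetical stand-ins declared locally so that the sketch compiles without the library, and the helper names describe, isNumeric and castTo are likewise made up for this example. It shows an instanceof check on a tensor instance, an isAssignableFrom check on a class token (the role formerly played by a DTYPE field), and a guarded cast.

```java
public class TensorTypeCheckSketch {

    // Hypothetical stand-ins, NOT the real TensorFlow Java interfaces. They only
    // reproduce the relationship described above: every tensor type is a Tensor.
    interface Tensor {}
    interface TType extends Tensor {}
    interface TNumber extends TType {}
    interface TFloat32 extends TNumber {}
    interface TInt32 extends TNumber {}
    interface TString extends TType {}

    // Instance check: plays the role of comparing a tensor's data type to a DTYPE field.
    static String describe(Tensor tensor) {
        if (tensor instanceof TFloat32) {
            return "float32";
        }
        if (tensor instanceof TInt32) {
            return "int32";
        }
        return "other";
    }

    // Class-token check: useful when only the type is known, not a tensor instance.
    static boolean isNumeric(Class<? extends TType> type) {
        return TNumber.class.isAssignableFrom(type);
    }

    // Guarded cast: fails with a readable message instead of a bare ClassCastException.
    static <T extends TType> T castTo(Tensor tensor, Class<T> type) {
        if (!type.isInstance(tensor)) {
            throw new IllegalArgumentException("Tensor is not a " + type.getSimpleName());
        }
        return type.cast(tensor);
    }

    public static void main(String[] args) {
        // Anonymous implementations stand in for tensors returned by the library.
        Tensor floats = new TFloat32() {};
        Tensor strings = new TString() {};

        System.out.println(describe(floats));            // float32
        System.out.println(describe(strings));           // other
        System.out.println(isNumeric(TFloat32.class));   // true
        System.out.println(isNumeric(TString.class));    // false

        TFloat32 typed = castTo(floats, TFloat32.class);
        System.out.println(typed == floats);             // true, same object viewed through its type
    }
}
```

With the real library, the same pattern would apply to the interfaces it actually ships; the guarded cast is optional, since a plain (TFloat32)tensor cast behaves the same when the type matches.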
0.2.0:
Session session = ...;
try (Tensor<TFloat32> tensor = TFloat32.tensorOf(Shape.of(1, 2))) {
    // Explicit conversion from the tensor to its typed data
    TFloat32 tensorData = tensor.data();
    // The tensor has rank 2, so two coordinates are needed to address a scalar
    tensorData.setFloat(10.0f, 0, 0);
    tensorData.setFloat(20.0f, 0, 1);
    try (Tensor<?> result = session.runner().feed("x", tensor).fetch("y").run().get(0)) {
        if (result.dataType() == TFloat32.DTYPE) {
            Tensor<TFloat32> typedResult = result.expect(TFloat32.DTYPE);
            TFloat32 resultData = typedResult.data();
            // getFloat() without coordinates assumes the result is a scalar
            System.out.println("Result is " + resultData.getFloat());
        }
    }
}
0.3.0:
Session session = ...;
try (TFloat32 tensor = TFloat32.tensorOf(Shape.of(1, 2))) {
    // The tensor has rank 2, so two coordinates are needed to address a scalar
    tensor.setFloat(10.0f, 0, 0);
    tensor.setFloat(20.0f, 0, 1);
    try (Tensor result = session.runner().feed("x", tensor).fetch("y").run().get(0)) {
        if (result instanceof TFloat32) {
            TFloat32 typedResult = (TFloat32)result;
            // getFloat() without coordinates assumes the result is a scalar
            System.out.println("Result is " + typedResult.getFloat());
        }
    }
}