diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/KIOperatorFactory.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/KIOperatorFactory.kt index 7c081523f..ddae86ca9 100755 --- a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/KIOperatorFactory.kt +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/KIOperatorFactory.kt @@ -20,8 +20,7 @@ import io.kinference.core.operators.convolution.* import io.kinference.core.operators.pool.* import io.kinference.core.operators.quantization.* import io.kinference.core.operators.quantization.lstm.DynamicQuantizeLSTM -import io.kinference.core.operators.seq.ConcatFromSequence -import io.kinference.core.operators.seq.SplitToSequence +import io.kinference.core.operators.seq.* import io.kinference.core.operators.tensor.* import io.kinference.graph.Graph import io.kinference.operator.* @@ -125,6 +124,7 @@ object KIOperatorFactory : OperatorFactory> { "Reshape" -> Reshape(name, version, attributes, inputs, outputs) "ScatterElements" -> ScatterElements(name, version, attributes, inputs, outputs) "ScatterND" -> ScatterND(name, version, attributes, inputs, outputs) + "SequenceLength" -> SequenceLength(name, version, attributes, inputs, outputs) "Shape" -> Shape(name, version, attributes, inputs, outputs) "Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs) "Sign" -> Sign(name, version, attributes, inputs, outputs) diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/seq/SequenceLength.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/seq/SequenceLength.kt new file mode 100644 index 000000000..9521a7a10 --- /dev/null +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/seq/SequenceLength.kt @@ -0,0 +1,56 @@ +package io.kinference.core.operators.seq + +import io.kinference.attribute.Attribute +import io.kinference.core.data.seq.KIONNXSequence +import io.kinference.core.data.tensor.KITensor +import io.kinference.core.data.tensor.asTensor +import io.kinference.data.ONNXData +import io.kinference.data.ONNXDataType +import io.kinference.graph.Contexts +import io.kinference.ndarray.arrays.LongNDArray +import io.kinference.operator.* +import io.kinference.protobuf.message.TensorProto + +sealed class SequenceLength( + name: String, + info: OperatorInfo, + attributes: Map>, + inputs: List, + outputs: List +) : Operator(name, info, attributes, inputs, outputs) { + companion object { + private val DEFAULT_VERSION = VersionInfo(sinceVersion = 11) + + operator fun invoke(name: String, version: Int?, attributes: Map>, inputs: List, outputs: List): SequenceLength { + return when (version ?: DEFAULT_VERSION.sinceVersion) { + in SequenceLengthVer11.VERSION.asRange() -> SequenceLengthVer11(name, attributes, inputs, outputs) + else -> error("Unsupported version of SequenceLength operator: $version") + } + } + } +} + + +class SequenceLengthVer11 internal constructor( + name: String, + attributes: Map>, + inputs: List, + outputs: List +) : SequenceLength(name, INFO, attributes, inputs, outputs) { + companion object { + private val TYPE_CONSTRAINTS = ALL_DATA_TYPES + + private val INPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input_sequence", optional = false, onnxDataType = ONNXDataType.ONNX_SEQUENCE)) + + private val OUTPUTS_INFO = listOf(IOInfo(0, setOf(TensorProto.DataType.INT64), "length", optional = false)) + + internal val VERSION = VersionInfo(sinceVersion = 11) + private val INFO = OperatorInfo("SequenceLength", emptyMap(), INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) + } + + override suspend fun > apply(contexts: Contexts, inputs: List): List { + val seq = inputs.first()!!.data + val seqLength = LongNDArray.scalar(seq.size.toLong()) + return listOf(seqLength.asTensor("length")) + } +} diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt index 9340e4fcb..f61e0dbdf 100755 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt @@ -21,8 +21,7 @@ import io.kinference.tfjs.operators.math.* import io.kinference.tfjs.operators.ml.* import io.kinference.tfjs.operators.quantization.DequantizeLinear import io.kinference.tfjs.operators.quantization.DynamicQuantizeLinear -import io.kinference.tfjs.operators.seq.ConcatFromSequence -import io.kinference.tfjs.operators.seq.SplitToSequence +import io.kinference.tfjs.operators.seq.* import io.kinference.tfjs.operators.tensor.* object TFJSAttributeFactory : AttributeFactory> { @@ -108,6 +107,7 @@ object TFJSOperatorFactory : OperatorFactory> { "Reshape" -> Reshape(name, version, attributes, inputs, outputs) "ScatterElements" -> ScatterElements(name, version, attributes, inputs, outputs) "ScatterND" -> ScatterND(name, version, attributes, inputs, outputs) + "SequenceLength" -> SequenceLength(name, version, attributes, inputs, outputs) "Shape" -> Shape(name, version, attributes, inputs, outputs) "Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs) "Sign" -> Sign(name, version, attributes, inputs, outputs) diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/seq/SequenceLength.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/seq/SequenceLength.kt new file mode 100644 index 000000000..a429361f0 --- /dev/null +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/seq/SequenceLength.kt @@ -0,0 +1,56 @@ +package io.kinference.tfjs.operators.seq + +import io.kinference.attribute.Attribute +import io.kinference.data.ONNXData +import io.kinference.data.ONNXDataType +import io.kinference.graph.Contexts +import io.kinference.ndarray.arrays.NDArrayTFJS +import io.kinference.operator.* +import io.kinference.protobuf.message.TensorProto +import io.kinference.tfjs.data.seq.TFJSSequence +import io.kinference.tfjs.data.tensors.TFJSTensor +import io.kinference.tfjs.data.tensors.asTensor + +sealed class SequenceLength( + name: String, + info: OperatorInfo, + attributes: Map>, + inputs: List, + outputs: List +) : Operator(name, info, attributes, inputs, outputs) { + companion object { + private val DEFAULT_VERSION = VersionInfo(sinceVersion = 11) + + operator fun invoke(name: String, version: Int?, attributes: Map>, inputs: List, outputs: List): SequenceLength { + return when (version ?: DEFAULT_VERSION.sinceVersion) { + in SequenceLengthVer11.VERSION.asRange() -> SequenceLengthVer11(name, attributes, inputs, outputs) + else -> error("Unsupported version of SequenceLength operator: $version") + } + } + } +} + + +class SequenceLengthVer11 internal constructor( + name: String, + attributes: Map>, + inputs: List, + outputs: List +) : SequenceLength(name, INFO, attributes, inputs, outputs) { + companion object { + private val TYPE_CONSTRAINTS = ALL_DATA_TYPES + + private val INPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input_sequence", optional = false, onnxDataType = ONNXDataType.ONNX_SEQUENCE)) + + private val OUTPUTS_INFO = listOf(IOInfo(0, setOf(TensorProto.DataType.INT64), "length", optional = false)) + + internal val VERSION = VersionInfo(sinceVersion = 11) + private val INFO = OperatorInfo("SequenceLength", emptyMap(), INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) + } + + override suspend fun > apply(contexts: Contexts, inputs: List): List { + val seq = inputs.first()!!.data + val seqLength = NDArrayTFJS.intScalar(seq.size) + return listOf(seqLength.asTensor("length")) + } +}