Skip to content

Commit

Permalink
Merge pull request #115 from JetBrains-Research/seqLen-op
Browse files Browse the repository at this point in the history
KI-46 [core, tfjs] Add SequenceLength operator
  • Loading branch information
cupertank authored Jul 19, 2023
2 parents 7f36c0f + 1439a47 commit fa63cb5
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ import io.kinference.core.operators.convolution.*
import io.kinference.core.operators.pool.*
import io.kinference.core.operators.quantization.*
import io.kinference.core.operators.quantization.lstm.DynamicQuantizeLSTM
import io.kinference.core.operators.seq.ConcatFromSequence
import io.kinference.core.operators.seq.SplitToSequence
import io.kinference.core.operators.seq.*
import io.kinference.core.operators.tensor.*
import io.kinference.graph.Graph
import io.kinference.operator.*
Expand Down Expand Up @@ -125,6 +124,7 @@ object KIOperatorFactory : OperatorFactory<KIONNXData<*>> {
"Reshape" -> Reshape(name, version, attributes, inputs, outputs)
"ScatterElements" -> ScatterElements(name, version, attributes, inputs, outputs)
"ScatterND" -> ScatterND(name, version, attributes, inputs, outputs)
"SequenceLength" -> SequenceLength(name, version, attributes, inputs, outputs)
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Sign" -> Sign(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.kinference.core.operators.seq

import io.kinference.attribute.Attribute
import io.kinference.core.data.seq.KIONNXSequence
import io.kinference.core.data.tensor.KITensor
import io.kinference.core.data.tensor.asTensor
import io.kinference.data.ONNXData
import io.kinference.data.ONNXDataType
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.LongNDArray
import io.kinference.operator.*
import io.kinference.protobuf.message.TensorProto

sealed class SequenceLength(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<KIONNXSequence, KITensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 11)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): SequenceLength {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SequenceLengthVer11.VERSION.asRange() -> SequenceLengthVer11(name, attributes, inputs, outputs)
else -> error("Unsupported version of SequenceLength operator: $version")
}
}
}
}


class SequenceLengthVer11 internal constructor(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : SequenceLength(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = ALL_DATA_TYPES

private val INPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input_sequence", optional = false, onnxDataType = ONNXDataType.ONNX_SEQUENCE))

private val OUTPUTS_INFO = listOf(IOInfo(0, setOf(TensorProto.DataType.INT64), "length", optional = false))

internal val VERSION = VersionInfo(sinceVersion = 11)
private val INFO = OperatorInfo("SequenceLength", emptyMap(), INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KIONNXSequence?>): List<KITensor?> {
val seq = inputs.first()!!.data
val seqLength = LongNDArray.scalar(seq.size.toLong())
return listOf(seqLength.asTensor("length"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ import io.kinference.tfjs.operators.math.*
import io.kinference.tfjs.operators.ml.*
import io.kinference.tfjs.operators.quantization.DequantizeLinear
import io.kinference.tfjs.operators.quantization.DynamicQuantizeLinear
import io.kinference.tfjs.operators.seq.ConcatFromSequence
import io.kinference.tfjs.operators.seq.SplitToSequence
import io.kinference.tfjs.operators.seq.*
import io.kinference.tfjs.operators.tensor.*

object TFJSAttributeFactory : AttributeFactory<TFJSData<*>> {
Expand Down Expand Up @@ -108,6 +107,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"Reshape" -> Reshape(name, version, attributes, inputs, outputs)
"ScatterElements" -> ScatterElements(name, version, attributes, inputs, outputs)
"ScatterND" -> ScatterND(name, version, attributes, inputs, outputs)
"SequenceLength" -> SequenceLength(name, version, attributes, inputs, outputs)
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Sign" -> Sign(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.kinference.tfjs.operators.seq

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.data.ONNXDataType
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.NDArrayTFJS
import io.kinference.operator.*
import io.kinference.protobuf.message.TensorProto
import io.kinference.tfjs.data.seq.TFJSSequence
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class SequenceLength(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<TFJSSequence, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 11)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): SequenceLength {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SequenceLengthVer11.VERSION.asRange() -> SequenceLengthVer11(name, attributes, inputs, outputs)
else -> error("Unsupported version of SequenceLength operator: $version")
}
}
}
}


class SequenceLengthVer11 internal constructor(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : SequenceLength(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = ALL_DATA_TYPES

private val INPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input_sequence", optional = false, onnxDataType = ONNXDataType.ONNX_SEQUENCE))

private val OUTPUTS_INFO = listOf(IOInfo(0, setOf(TensorProto.DataType.INT64), "length", optional = false))

internal val VERSION = VersionInfo(sinceVersion = 11)
private val INFO = OperatorInfo("SequenceLength", emptyMap(), INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSSequence?>): List<TFJSTensor?> {
val seq = inputs.first()!!.data
val seqLength = NDArrayTFJS.intScalar(seq.size)
return listOf(seqLength.asTensor("length"))
}
}

0 comments on commit fa63cb5

Please sign in to comment.