Skip to content

Commit

Permalink
KI-40 [tfjs] Add Size operator
Browse files Browse the repository at this point in the history
  • Loading branch information
AnastasiaTuchina committed Jul 14, 2023
1 parent 9e8ee74 commit aec6ea5
Show file tree
Hide file tree
Showing 12 changed files with 61 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.operator.*
import io.kinference.protobuf.message.TensorProto

sealed class Size(
name: String,
Expand Down Expand Up @@ -42,7 +43,7 @@ class SizeVer1(
)

private val OUTPUTS_INFO = listOf(
IOInfo(0, PRIMITIVE_DATA_TYPES, "size", differentiable = true, optional = false)
IOInfo(0, setOf(TensorProto.DataType.INT64), "size", differentiable = true, optional = false)
)

internal val VERSION = VersionInfo(sinceVersion = 1)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"ScatterND" -> ScatterND(name, version, attributes, inputs, outputs)
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Size" -> Size(name, version, attributes, inputs, outputs)
"SkipLayerNormalization" -> SkipLayerNormalization(name, version, attributes, inputs, outputs)
"Slice" -> Slice(name, version, attributes, inputs, outputs)
"Softmax" -> Softmax(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.kinference.tfjs.operators.tensor

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.NDArrayTFJS
import io.kinference.operator.*
import io.kinference.protobuf.message.TensorProto
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class Size(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 1)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Size {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SizeVer1.VERSION.asRange() -> SizeVer1(name, attributes, inputs, outputs)
else -> error("Unsupported version of Size operator: $version")
}
}
}
}


class SizeVer1(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Size(name, INFO, attributes, inputs, outputs) {
companion object {
private val ATTRIBUTES_INFO = emptyList<AttributeInfo>()

private val INPUTS_INFO = listOf(
IOInfo(0, PRIMITIVE_DATA_TYPES, "data", differentiable = true, optional = false)
)

private val OUTPUTS_INFO = listOf(
IOInfo(0, setOf(TensorProto.DataType.INT64), "size", differentiable = true, optional = false)
)

internal val VERSION = VersionInfo(sinceVersion = 1)
private val INFO = OperatorInfo("Size", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> {
val inputSize = inputs[0]!!.data.linearSize
val dataSize = NDArrayTFJS.intScalar(inputSize.toLong())
return listOf(dataSize.asTensor("size"))
}
}

This file was deleted.

Binary file not shown.

This file was deleted.

Binary file not shown.

This file was deleted.

Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit aec6ea5

Please sign in to comment.