Skip to content

Commit

Permalink
Merge pull request #110 from JetBrains-Research/size-op
Browse files Browse the repository at this point in the history
KI-40 Size operator
  • Loading branch information
cupertank authored Jul 17, 2023
2 parents dfcd7bb + fa5a47c commit c71879b
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ object KIOperatorFactory : OperatorFactory<KIONNXData<*>> {
"ScatterND" -> ScatterND(name, version, attributes, inputs, outputs)
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Size" -> Size(name, version, attributes, inputs, outputs)
"SkipLayerNormalization" -> SkipLayerNormalization(name, version, attributes, inputs, outputs)
"Slice" -> Slice(name, version, attributes, inputs, outputs)
"Softmax" -> Softmax(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.kinference.core.operators.tensor

import io.kinference.attribute.Attribute
import io.kinference.core.data.tensor.KITensor
import io.kinference.core.data.tensor.asTensor
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.operator.*
import io.kinference.protobuf.message.TensorProto

sealed class Size(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 1)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Size {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SizeVer1.VERSION.asRange() -> SizeVer1(name, attributes, inputs, outputs)
else -> error("Unsupported version of Size operator: $version")
}
}
}
}


class SizeVer1(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Size(name, INFO, attributes, inputs, outputs) {
companion object {
private val ATTRIBUTES_INFO = emptyList<AttributeInfo>()

private val INPUTS_INFO = listOf(
IOInfo(0, PRIMITIVE_DATA_TYPES, "data", differentiable = true, optional = false)
)

private val OUTPUTS_INFO = listOf(
IOInfo(0, setOf(TensorProto.DataType.INT64), "size", differentiable = true, optional = false)
)

internal val VERSION = VersionInfo(sinceVersion = 1)
private val INFO = OperatorInfo("Size", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val inputSize = inputs[0]!!.data.linearSize
val dataSize = LongNDArray.scalar(inputSize.toLong())
return listOf(dataSize.asTensor("size"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ class TanTest {
private fun getTargetPath(dirName: String) = "tan/$dirName/"

@Test
fun test_tanh_example() = TestRunner.runTest {
fun test_tan_example() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_tan_example"))
}

@Test
fun test_tanh() = TestRunner.runTest {
fun test_tan() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_tan"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.operators.operations

import io.kinference.KITestEngine
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SizeTest {
private fun getTargetPath(dirName: String) = "size/$dirName/"

@Test
fun test_size_example() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_size_example"))
}

@Test
fun test_size() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_size"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"ScatterND" -> ScatterND(name, version, attributes, inputs, outputs)
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Size" -> Size(name, version, attributes, inputs, outputs)
"SkipLayerNormalization" -> SkipLayerNormalization(name, version, attributes, inputs, outputs)
"Slice" -> Slice(name, version, attributes, inputs, outputs)
"Softmax" -> Softmax(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.kinference.tfjs.operators.tensor

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.NDArrayTFJS
import io.kinference.operator.*
import io.kinference.protobuf.message.TensorProto
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class Size(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 1)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Size {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SizeVer1.VERSION.asRange() -> SizeVer1(name, attributes, inputs, outputs)
else -> error("Unsupported version of Size operator: $version")
}
}
}
}


class SizeVer1(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Size(name, INFO, attributes, inputs, outputs) {
companion object {
private val ATTRIBUTES_INFO = emptyList<AttributeInfo>()

private val INPUTS_INFO = listOf(
IOInfo(0, PRIMITIVE_DATA_TYPES, "data", differentiable = true, optional = false)
)

private val OUTPUTS_INFO = listOf(
IOInfo(0, setOf(TensorProto.DataType.INT64), "size", differentiable = true, optional = false)
)

internal val VERSION = VersionInfo(sinceVersion = 1)
private val INFO = OperatorInfo("Size", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> {
val inputSize = inputs[0]!!.data.linearSize
val dataSize = NDArrayTFJS.intScalar(inputSize)
return listOf(dataSize.asTensor("size"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ class TanTest {
private fun getTargetPath(dirName: String) = "tan/$dirName/"

@Test
fun test_tanh_example() = TestRunner.runTest {
fun test_tan_example() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tan_example"))
}

@Test
fun test_tanh() = TestRunner.runTest {
fun test_tan() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tan"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.tfjs.operators.tensor

import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SizeTest {
private fun getTargetPath(dirName: String) = "size/$dirName/"

@Test
fun test_size_example() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_size_example"))
}

@Test
fun test_size() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_size"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_data_set_0/input_0.pb
test_data_set_0/output_0.pb
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
BxJ�x��?h��>��z?�j@$ �?�.z��8s?b��hdӽ�9�>(�>�%�?^�B?�0�= B�>]ת>�=�?R�iJ�>�Z�/d#��S'?�K]?��=��C@�(��Hm;= �?�2�?��?��>���>�Ec������!��� >*z�?��?�Oƾmǚ��6��&õ�gڿ��?�x�FKྙ[��� G?4�ο��Y�L=e��> �����k��QN�>.:�=�ݚ>�b"�6���
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_data_set_0/input_0.pb
test_data_set_0/output_0.pb
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit c71879b

Please sign in to comment.