Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KI-43 [core, tfjs] Add Sin operator #111

Merged
merged 1 commit into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ object KIOperatorFactory : OperatorFactory<KIONNXData<*>> {
"ScatterND" -> ScatterND(name, version, attributes, inputs, outputs)
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Sin" -> Sin(name, version, attributes, inputs, outputs)
"Size" -> Size(name, version, attributes, inputs, outputs)
"SkipLayerNormalization" -> SkipLayerNormalization(name, version, attributes, inputs, outputs)
"Slice" -> Slice(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.kinference.core.operators.activations

import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.sin
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

sealed class Sin(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Activation(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 7)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) =
when (version ?: DEFAULT_VERSION.sinceVersion) {
in SinVer7.VERSION.asRange() -> SinVer7(name, attributes, inputs, outputs)
else -> error("Unsupported version of Sin operator: $version")
}
}
}


class SinVer7(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Sin(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val INPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input", optional = false))
private val OUTPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "output", optional = false))

internal val VERSION = VersionInfo(sinceVersion = 7)
private val INFO = OperatorInfo("Sin", emptySet(), INPUT_INFO, OUTPUT_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun activate(input: NDArrayCore, contexts: Contexts<KIONNXData<*>>): NDArrayCore {
return when (val type = input.type) {
DataType.FLOAT -> (input as FloatNDArray).sin()
DataType.DOUBLE -> (input as DoubleNDArray).sin()
else -> error("Unsupported data type for this operation: $type")
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.operators.activations

import io.kinference.KITestEngine
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SinTest {
private fun getTargetPath(dirName: String) = "sin/$dirName/"

@Test
fun test_sin() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sin"))
}

@Test
fun test_sin_example() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sin_example"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"ScatterND" -> ScatterND(name, version, attributes, inputs, outputs)
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Sin" -> Sin(name, version, attributes, inputs, outputs)
"Size" -> Size(name, version, attributes, inputs, outputs)
"SkipLayerNormalization" -> SkipLayerNormalization(name, version, attributes, inputs, outputs)
"Slice" -> Slice(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.kinference.tfjs.operators.activations

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.NumberNDArrayTFJS
import io.kinference.ndarray.extensions.sin
import io.kinference.operator.*
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class Sin(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 7)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Sin {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SinVer7.VERSION.asRange() -> SinVer7(name, attributes, inputs, outputs)
else -> error("Unsupported version of Sin operator: $version")
}
}
}
}

class SinVer7(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Sin(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val ATTRIBUTES_INFO = emptyList<AttributeInfo>()

private val INPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input", optional = false))
private val OUTPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "output", optional = false))

internal val VERSION = VersionInfo(sinceVersion = 7)
private val INFO = OperatorInfo("Sin", ATTRIBUTES_INFO, INPUT_INFO, OUTPUT_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> {
val input = inputs[0]!!.data as NumberNDArrayTFJS
return listOf(input.sin().asTensor("output"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.tfjs.operators.activations

import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SinTest {
private fun getTargetPath(dirName: String) = "sin/$dirName/"

@Test
fun test_sin() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_sin"))
}

@Test
fun test_sin_example() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_sin_example"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fun PrimitiveNDArray.cosh(): PrimitiveNDArray = applyElementWise { cosh(it) }

fun PrimitiveNDArray.asin(): PrimitiveNDArray = applyElementWise { asin(it) }
fun PrimitiveNDArray.asinh(): PrimitiveNDArray = applyElementWise { asinh(it) }
fun PrimitiveNDArray.sin(): PrimitiveNDArray = applyElementWise { sin(it) }

fun PrimitiveNDArray.atan(): PrimitiveNDArray = applyElementWise { atan(it) }
fun PrimitiveNDArray.atanh(): PrimitiveNDArray = applyElementWise { atanh(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ internal fun cosh(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationE

internal fun asin(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
internal fun asinh(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
internal fun sin(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()

internal fun atan(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
internal fun atanh(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,5 @@ internal external val floor: (x: ArrayTFJS) -> ArrayTFJS
internal external val isInf: (x: ArrayTFJS) -> ArrayTFJS

internal external val isNaN: (x: ArrayTFJS) -> ArrayTFJS

internal external val sin: (x: ArrayTFJS) -> ArrayTFJS
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ fun NumberNDArrayTFJS.atanh() = NumberNDArrayTFJS(tfjsArray.atanh())

fun NumberNDArrayTFJS.tan() = NumberNDArrayTFJS(tfjsArray.tan())

fun NumberNDArrayTFJS.sin() = NumberNDArrayTFJS(tfjsArray.sin())

fun NumberNDArrayTFJS.moments(axis: Int, keepDims: Boolean = false) = tfjsArray.moments(axis, keepDims).toNDArray()

fun NumberNDArrayTFJS.moments(axes: Array<Int>, keepDims: Boolean = false) = tfjsArray.moments(axes, keepDims).toNDArray()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,5 @@ internal fun ArrayTFJS.isInf() = isInf(this)
internal fun ArrayTFJS.isNaN() = isNaN(this)

internal fun ArrayTFJS.bandPart(numLower: Int = 0, numUpper: Int = 0) = linalg.bandPart(this, numLower, numUpper)

internal fun ArrayTFJS.sin() = sin(this)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_data_set_0/input_0.pb
test_data_set_0/output_0.pb
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
BxJ�x��?h��>��z?�j@$ �?�.z��8s?b��hdӽ�9�>(�>�%�?^�B?�0�= B�>]ת>�=�?R�iJ�>�Z�/d#��S'?�K]?��=��C@�(��Hm;= �?�2�?��?��>���>�Ec������!��� >*z�?��?�Oƾmǚ��6��&õ�gڿ��?�x�FKྙ[��� G?4�ο��Y�L=e��> �����k��QN�>.:�=�ݚ>�b"�6���
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ByJ��;{?�t�>]mT?��H?R�t?�7T�_?P?f�]ӽv^�>��>�C~?m�0?v��=t��>L��>=??��P�a��>�A��!��?*�B?�-���C?LD~��\;=�>���?�~?3>� �>��F�m�j�#���*s>OMq?u�n?�c���n��4�]�u}���}�z�m?����E0پ.*s��3?#���4X���G�O0�>sN����l����F��>x �=Ճ�>���`���
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_data_set_0/input_0.pb
test_data_set_0/output_0.pb
Binary file not shown.
Binary file not shown.
Binary file not shown.