Skip to content

Commit

Permalink
KI-44 [core, tfjs] Add Sinh operator
Browse files Browse the repository at this point in the history
  • Loading branch information
AnastasiaTuchina committed Jul 17, 2023
1 parent c71879b commit bc972b4
Show file tree
Hide file tree
Showing 19 changed files with 160 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ object KIOperatorFactory : OperatorFactory<KIONNXData<*>> {
"ScatterND" -> ScatterND(name, version, attributes, inputs, outputs)
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Sinh" -> Sinh(name, version, attributes, inputs, outputs)
"Size" -> Size(name, version, attributes, inputs, outputs)
"SkipLayerNormalization" -> SkipLayerNormalization(name, version, attributes, inputs, outputs)
"Slice" -> Slice(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.kinference.core.operators.activations

import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.sinh
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

sealed class Sinh(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Activation(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 9)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) =
when (version ?: DEFAULT_VERSION.sinceVersion) {
in SinhVer9.VERSION.asRange() -> SinhVer9(name, attributes, inputs, outputs)
else -> error("Unsupported version of Sinh operator: $version")
}
}
}


class SinhVer9(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Sinh(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val INPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input", optional = false))
private val OUTPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "output", optional = false))

internal val VERSION = VersionInfo(sinceVersion = 9)
private val INFO = OperatorInfo("Sinh", emptySet(), INPUT_INFO, OUTPUT_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun activate(input: NDArrayCore, contexts: Contexts<KIONNXData<*>>): NDArrayCore {
return when (val type = input.type) {
DataType.FLOAT -> (input as FloatNDArray).sinh()
DataType.DOUBLE -> (input as DoubleNDArray).sinh()
else -> error("Unsupported data type for this operation: $type")
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.operators.activations

import io.kinference.KITestEngine
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SinhTest {
private fun getTargetPath(dirName: String) = "sinh/$dirName/"

@Test
fun test_sinh() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sinh"))
}

@Test
fun test_sinh_example() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sinh_example"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"ScatterND" -> ScatterND(name, version, attributes, inputs, outputs)
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Sinh" -> Sinh(name, version, attributes, inputs, outputs)
"Size" -> Size(name, version, attributes, inputs, outputs)
"SkipLayerNormalization" -> SkipLayerNormalization(name, version, attributes, inputs, outputs)
"Slice" -> Slice(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.kinference.tfjs.operators.activations

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.NumberNDArrayTFJS
import io.kinference.ndarray.extensions.sinh
import io.kinference.operator.*
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class Sinh(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 9)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Sinh {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SinhVer9.VERSION.asRange() -> SinhVer9(name, attributes, inputs, outputs)
else -> error("Unsupported version of Sinh operator: $version")
}
}
}
}

class SinhVer9(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Sinh(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val ATTRIBUTES_INFO = emptyList<AttributeInfo>()

private val INPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input", optional = false))
private val OUTPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "output", optional = false))

internal val VERSION = VersionInfo(sinceVersion = 9)
private val INFO = OperatorInfo("Sinh", ATTRIBUTES_INFO, INPUT_INFO, OUTPUT_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> {
val input = inputs[0]!!.data as NumberNDArrayTFJS
return listOf(input.sinh().asTensor("output"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.tfjs.operators.activations

import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SinhTest {
private fun getTargetPath(dirName: String) = "sinh/$dirName/"

@Test
fun test_sinh() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_sinh"))
}

@Test
fun test_sinh_example() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_sinh_example"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fun PrimitiveNDArray.cosh(): PrimitiveNDArray = applyElementWise { cosh(it) }

fun PrimitiveNDArray.asin(): PrimitiveNDArray = applyElementWise { asin(it) }
fun PrimitiveNDArray.asinh(): PrimitiveNDArray = applyElementWise { asinh(it) }
fun PrimitiveNDArray.sinh(): PrimitiveNDArray = applyElementWise { sinh(it) }

fun PrimitiveNDArray.atan(): PrimitiveNDArray = applyElementWise { atan(it) }
fun PrimitiveNDArray.atanh(): PrimitiveNDArray = applyElementWise { atanh(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ internal fun cosh(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationE

internal fun asin(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
internal fun asinh(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
internal fun sinh(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()

internal fun atan(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
internal fun atanh(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,5 @@ internal external val floor: (x: ArrayTFJS) -> ArrayTFJS
internal external val isInf: (x: ArrayTFJS) -> ArrayTFJS

internal external val isNaN: (x: ArrayTFJS) -> ArrayTFJS

internal external val sinh: (x: ArrayTFJS) -> ArrayTFJS
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ fun NumberNDArrayTFJS.asin() = NumberNDArrayTFJS(tfjsArray.asin())

fun NumberNDArrayTFJS.asinh() = NumberNDArrayTFJS(tfjsArray.asinh())

fun NumberNDArrayTFJS.sinh() = NumberNDArrayTFJS(tfjsArray.sinh())

fun NumberNDArrayTFJS.atan() = NumberNDArrayTFJS(tfjsArray.atan())

fun NumberNDArrayTFJS.atanh() = NumberNDArrayTFJS(tfjsArray.atanh())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,8 @@ internal fun ArrayTFJS.asin() = asin(this)

internal fun ArrayTFJS.asinh() = asinh(this)

internal fun ArrayTFJS.sinh() = sinh(this)

internal fun ArrayTFJS.atan() = atan(this)

internal fun ArrayTFJS.atanh() = atanh(this)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_data_set_0/input_0.pb
test_data_set_0/output_0.pb
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
BxJ�x��?h��>��z?�j@$ �?�.z��8s?b��hdӽ�9�>(�>�%�?^�B?�0�= B�>]ת>�=�?R�iJ�>�Z�/d#��S'?�K]?��=��C@�(��Hm;= �?�2�?��?��>���>�Ec������!��� >*z�?��?�Oƾmǚ��6��&õ�gڿ��?�x�FKྙ[��� G?4�ο��Y�L=e��> �����k��QN�>.:�=�ݚ>�b"�6���
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ByJ�$E5@gd�>�B�?̹�@�,J@����6��?����ӽ�.�>�>��@//V?F��=E��>w�>�b@�S��>�7v�H��a3?��y?6�O��-�@��~;=,�@�uH@W�@DM>cD�>�)����c����"� >0R�?^��?�N˾�%��`2���O��ru*��\@g1����ݸͿ��[?�W�5~[�\�����>��¼�>����>�S�=�<�>V-��ҽ�
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_data_set_0/input_0.pb
test_data_set_0/output_0.pb
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit bc972b4

Please sign in to comment.