Skip to content

Commit

Permalink
Merge branch 'master' into sign-op
Browse files Browse the repository at this point in the history
  • Loading branch information
cupertank authored Jul 18, 2023
2 parents a362f3c + 4f4e144 commit 3472434
Show file tree
Hide file tree
Showing 19 changed files with 160 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ object KIOperatorFactory : OperatorFactory<KIONNXData<*>> {
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Sign" -> Sign(name, version, attributes, inputs, outputs)
"Sin" -> Sin(name, version, attributes, inputs, outputs)
"Size" -> Size(name, version, attributes, inputs, outputs)
"SkipLayerNormalization" -> SkipLayerNormalization(name, version, attributes, inputs, outputs)
"Slice" -> Slice(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.kinference.core.operators.activations

import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.sin
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

sealed class Sin(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Activation(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 7)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) =
when (version ?: DEFAULT_VERSION.sinceVersion) {
in SinVer7.VERSION.asRange() -> SinVer7(name, attributes, inputs, outputs)
else -> error("Unsupported version of Sin operator: $version")
}
}
}


class SinVer7(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Sin(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val INPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input", optional = false))
private val OUTPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "output", optional = false))

internal val VERSION = VersionInfo(sinceVersion = 7)
private val INFO = OperatorInfo("Sin", emptySet(), INPUT_INFO, OUTPUT_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun activate(input: NDArrayCore, contexts: Contexts<KIONNXData<*>>): NDArrayCore {
return when (val type = input.type) {
DataType.FLOAT -> (input as FloatNDArray).sin()
DataType.DOUBLE -> (input as DoubleNDArray).sin()
else -> error("Unsupported data type for this operation: $type")
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.operators.activations

import io.kinference.KITestEngine
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SinTest {
private fun getTargetPath(dirName: String) = "sin/$dirName/"

@Test
fun test_sin() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sin"))
}

@Test
fun test_sin_example() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sin_example"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"Shape" -> Shape(name, version, attributes, inputs, outputs)
"Sigmoid" -> Sigmoid(name, version, attributes, inputs, outputs)
"Sign" -> Sign(name, version, attributes, inputs, outputs)
"Sin" -> Sin(name, version, attributes, inputs, outputs)
"Size" -> Size(name, version, attributes, inputs, outputs)
"SkipLayerNormalization" -> SkipLayerNormalization(name, version, attributes, inputs, outputs)
"Slice" -> Slice(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.kinference.tfjs.operators.activations

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.NumberNDArrayTFJS
import io.kinference.ndarray.extensions.sin
import io.kinference.operator.*
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class Sin(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 7)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Sin {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SinVer7.VERSION.asRange() -> SinVer7(name, attributes, inputs, outputs)
else -> error("Unsupported version of Sin operator: $version")
}
}
}
}

class SinVer7(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Sin(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val ATTRIBUTES_INFO = emptyList<AttributeInfo>()

private val INPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input", optional = false))
private val OUTPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "output", optional = false))

internal val VERSION = VersionInfo(sinceVersion = 7)
private val INFO = OperatorInfo("Sin", ATTRIBUTES_INFO, INPUT_INFO, OUTPUT_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> {
val input = inputs[0]!!.data as NumberNDArrayTFJS
return listOf(input.sin().asTensor("output"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.tfjs.operators.activations

import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SinTest {
private fun getTargetPath(dirName: String) = "sin/$dirName/"

@Test
fun test_sin() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_sin"))
}

@Test
fun test_sin_example() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_sin_example"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fun PrimitiveNDArray.cosh(): PrimitiveNDArray = applyElementWise { cosh(it) }

fun PrimitiveNDArray.asin(): PrimitiveNDArray = applyElementWise { asin(it) }
fun PrimitiveNDArray.asinh(): PrimitiveNDArray = applyElementWise { asinh(it) }
fun PrimitiveNDArray.sin(): PrimitiveNDArray = applyElementWise { sin(it) }

fun PrimitiveNDArray.atan(): PrimitiveNDArray = applyElementWise { atan(it) }
fun PrimitiveNDArray.atanh(): PrimitiveNDArray = applyElementWise { atanh(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ internal fun cosh(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationE

internal fun asin(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
internal fun asinh(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
internal fun sin(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()

internal fun atan(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
internal fun atanh(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,5 @@ internal external val isInf: (x: ArrayTFJS) -> ArrayTFJS
internal external val isNaN: (x: ArrayTFJS) -> ArrayTFJS

internal external val sign: (x: ArrayTFJS) -> ArrayTFJS

internal external val sin: (x: ArrayTFJS) -> ArrayTFJS
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ fun NumberNDArrayTFJS.atanh() = NumberNDArrayTFJS(tfjsArray.atanh())

fun NumberNDArrayTFJS.tan() = NumberNDArrayTFJS(tfjsArray.tan())

fun NumberNDArrayTFJS.sin() = NumberNDArrayTFJS(tfjsArray.sin())

fun NumberNDArrayTFJS.moments(axis: Int, keepDims: Boolean = false) = tfjsArray.moments(axis, keepDims).toNDArray()

fun NumberNDArrayTFJS.moments(axes: Array<Int>, keepDims: Boolean = false) = tfjsArray.moments(axes, keepDims).toNDArray()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,5 @@ internal fun ArrayTFJS.isNaN() = isNaN(this)
internal fun ArrayTFJS.bandPart(numLower: Int = 0, numUpper: Int = 0) = linalg.bandPart(this, numLower, numUpper)

internal fun ArrayTFJS.sign() = sign(this)

internal fun ArrayTFJS.sin() = sin(this)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_data_set_0/input_0.pb
test_data_set_0/output_0.pb
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
BxJ�x��?h��>��z?�j@$ �?�.z��8s?b��hdӽ�9�>(�>�%�?^�B?�0�= B�>]ת>�=�?R�iJ�>�Z�/d#��S'?�K]?��=��C@�(��Hm;= �?�2�?��?��>���>�Ec������!��� >*z�?��?�Oƾmǚ��6��&õ�gڿ��?�x�FKྙ[��� G?4�ο��Y�L=e��> �����k��QN�>.:�=�ݚ>�b"�6���
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ByJ��;{?�t�>]mT?��H?R�t?�7T�_?P?f�]ӽv^�>��>�C~?m�0?v��=t��>L��>=??��P�a��>�A��!��?*�B?�-���C?LD~��\;=�>���?�~?3>� �>��F�m�j�#���*s>OMq?u�n?�c���n��4�]�u}���}�z�m?����E0پ.*s��3?#���4X���G�O0�>sN����l����F��>x �=Ճ�>���`���
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_data_set_0/input_0.pb
test_data_set_0/output_0.pb
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 3472434

Please sign in to comment.