diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt index b2cd50abf..c91375e49 100755 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt @@ -114,6 +114,7 @@ object TFJSOperatorFactory : OperatorFactory> { "Softmax" -> Softmax(name, version, attributes, inputs, outputs) "Split" -> Split(name, version, attributes, inputs, outputs) "SplitToSequence" -> SplitToSequence(name, version, attributes, inputs, outputs) + "Sqrt" -> Sqrt(name, version, attributes, inputs, outputs) "Squeeze" -> Squeeze(name, version, attributes, inputs, outputs) "Sub" -> Sub(name, version, attributes, inputs, outputs) "Tan" -> Tan(name, version, attributes, inputs, outputs) diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/math/Sqrt.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/math/Sqrt.kt new file mode 100644 index 000000000..97a20e600 --- /dev/null +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/math/Sqrt.kt @@ -0,0 +1,58 @@ +package io.kinference.tfjs.operators.math + +import io.kinference.attribute.Attribute +import io.kinference.data.ONNXData +import io.kinference.graph.Contexts +import io.kinference.ndarray.arrays.NumberNDArrayTFJS +import io.kinference.ndarray.extensions.sqrt +import io.kinference.operator.* +import io.kinference.tfjs.data.tensors.TFJSTensor +import io.kinference.tfjs.data.tensors.asTensor + +sealed class Sqrt( + name: String, + info: OperatorInfo, + attributes: Map>, + inputs: List, + outputs: List +) : Operator(name, info, attributes, inputs, outputs) { + companion object { + private val DEFAULT_VERSION = VersionInfo(sinceVersion = 6) + + operator fun invoke(name: String, version: Int?, attributes: Map>, inputs: List, outputs: List): Sqrt { + return when (version ?: DEFAULT_VERSION.sinceVersion) { + in SqrtVer6.VERSION.asRange() -> SqrtVer6(name, attributes, inputs, outputs) + else -> error("Unsupported version of Sqrt operator: $version") + } + } + } +} + +class SqrtVer6( + name: String, + attributes: Map>, + inputs: List, + outputs: List +) : Sqrt(name, INFO, attributes, inputs, outputs) { + companion object { + private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES + + private val INPUTS_INFO = listOf( + IOInfo(0, TYPE_CONSTRAINTS, "X", optional = false) + ) + + private val OUTPUTS_INFO = listOf( + IOInfo(0, TYPE_CONSTRAINTS, "Y", optional = false) + ) + + internal val VERSION = VersionInfo(sinceVersion = 6) + private val INFO = OperatorInfo("Sqrt", emptyMap(), INPUTS_INFO, OUTPUTS_INFO, VERSION, domain = OperatorInfo.DEFAULT_DOMAIN) + } + + + override suspend fun > apply(contexts: Contexts, inputs: List): List { + val input = inputs[0]!!.data as NumberNDArrayTFJS + + return listOf(input.sqrt().asTensor("Y")) + } +} diff --git a/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/math/SqrtTest.kt b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/math/SqrtTest.kt new file mode 100644 index 000000000..3c18cdd79 --- /dev/null +++ b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/math/SqrtTest.kt @@ -0,0 +1,19 @@ +package io.kinference.tfjs.operators.math + +import io.kinference.tfjs.runners.TFJSTestEngine +import io.kinference.utils.TestRunner +import kotlin.test.Test + +class SqrtTest { + private fun getTargetPath(dirName: String) = "sqrt/$dirName/" + + @Test + fun test_sqrt() = TestRunner.runTest { + TFJSTestEngine.TFJSAccuracyRunner.runFromResources(getTargetPath("test_sqrt")) + } + + @Test + fun test_sqrt_example() = TestRunner.runTest { + TFJSTestEngine.TFJSAccuracyRunner.runFromResources(getTargetPath("test_sqrt_example")) + } +}