Skip to content

Commit

Permalink
KI-39 [tfjs] Add Sqrt operator
Browse files Browse the repository at this point in the history
  • Loading branch information
AnastasiaTuchina committed Jul 14, 2023
1 parent 75d1976 commit 5f28272
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"Softmax" -> Softmax(name, version, attributes, inputs, outputs)
"Split" -> Split(name, version, attributes, inputs, outputs)
"SplitToSequence" -> SplitToSequence(name, version, attributes, inputs, outputs)
"Sqrt" -> Sqrt(name, version, attributes, inputs, outputs)
"Squeeze" -> Squeeze(name, version, attributes, inputs, outputs)
"Sub" -> Sub(name, version, attributes, inputs, outputs)
"Tan" -> Tan(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.kinference.tfjs.operators.math

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.NumberNDArrayTFJS
import io.kinference.ndarray.extensions.sqrt
import io.kinference.operator.*
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class Sqrt(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 6)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Sqrt {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SqrtVer6.VERSION.asRange() -> SqrtVer6(name, attributes, inputs, outputs)
else -> error("Unsupported version of Sqrt operator: $version")
}
}
}
}

class SqrtVer6(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Sqrt(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val INPUTS_INFO = listOf(
IOInfo(0, TYPE_CONSTRAINTS, "X", optional = false)
)

private val OUTPUTS_INFO = listOf(
IOInfo(0, TYPE_CONSTRAINTS, "Y", optional = false)
)

internal val VERSION = VersionInfo(sinceVersion = 6)
private val INFO = OperatorInfo("Sqrt", emptyMap(), INPUTS_INFO, OUTPUTS_INFO, VERSION, domain = OperatorInfo.DEFAULT_DOMAIN)
}


override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> {
val input = inputs[0]!!.data as NumberNDArrayTFJS

return listOf(input.sqrt().asTensor("Y"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.tfjs.operators.math

import io.kinference.tfjs.runners.TFJSTestEngine
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SqrtTest {
private fun getTargetPath(dirName: String) = "sqrt/$dirName/"

@Test
fun test_sqrt() = TestRunner.runTest {
TFJSTestEngine.TFJSAccuracyRunner.runFromResources(getTargetPath("test_sqrt"))
}

@Test
fun test_sqrt_example() = TestRunner.runTest {
TFJSTestEngine.TFJSAccuracyRunner.runFromResources(getTargetPath("test_sqrt_example"))
}
}

0 comments on commit 5f28272

Please sign in to comment.