Skip to content

Commit

Permalink
Merge pull request #109 from JetBrains-Research/sqrt-op
Browse files Browse the repository at this point in the history
KI-39 Sqrt operator
  • Loading branch information
cupertank authored Jul 17, 2023
2 parents a051454 + 5f28272 commit 2860fa2
Show file tree
Hide file tree
Showing 45 changed files with 302 additions and 441 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ object KIOperatorFactory : OperatorFactory<KIONNXData<*>> {
"Softmax" -> Softmax(name, version, attributes, inputs, outputs)
"Split" -> Split(name, version, attributes, inputs, outputs)
"SplitToSequence" -> SplitToSequence(name, version, attributes, inputs, outputs)
"Sqrt" -> Sqrt(name, version, attributes, inputs, outputs)
"Squeeze" -> Squeeze(name, version, attributes, inputs, outputs)
"Sub" -> Sub(name, version, attributes, inputs, outputs)
"Tan" -> Tan(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.acos.acos
import io.kinference.ndarray.extensions.activations.acos
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.acosh.acosh
import io.kinference.ndarray.extensions.activations.acosh
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.asin.asin
import io.kinference.ndarray.extensions.activations.asin
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.asinh.asinh
import io.kinference.ndarray.extensions.activations.asinh
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.atan.atan
import io.kinference.ndarray.extensions.activations.atan
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.atanh.atanh
import io.kinference.ndarray.extensions.activations.atanh
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.cos.cos
import io.kinference.ndarray.extensions.activations.cos
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.cosh.cosh
import io.kinference.ndarray.extensions.activations.cosh
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.tan.tan
import io.kinference.ndarray.extensions.activations.tan
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package io.kinference.core.operators.math

import io.kinference.attribute.Attribute
import io.kinference.core.data.tensor.KITensor
import io.kinference.core.data.tensor.asTensor
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.sqrt
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

sealed class Sqrt(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 6)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Sqrt {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SqrtVer6.VERSION.asRange() -> SqrtVer6(name, attributes, inputs, outputs)
else -> error("Unsupported version of Sqrt operator: $version")
}
}
}
}

class SqrtVer6(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Sqrt(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val INPUTS_INFO = listOf(
IOInfo(0, TYPE_CONSTRAINTS, "X", optional = false)
)

private val OUTPUTS_INFO = listOf(
IOInfo(0, TYPE_CONSTRAINTS, "Y", optional = false)
)

internal val VERSION = VersionInfo(sinceVersion = 6)
private val INFO = OperatorInfo("Sqrt", emptyMap(), INPUTS_INFO, OUTPUTS_INFO, VERSION, domain = OperatorInfo.DEFAULT_DOMAIN)
}


override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val input = inputs[0]!!.data as NumberNDArrayCore
val output = when(val type = input.type) {
DataType.FLOAT -> (input as FloatNDArray).sqrt()
DataType.DOUBLE -> (input as DoubleNDArray).sqrt()
else -> error("Unsupported data type: $type")
}

return listOf(output.asTensor("Y"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import io.kinference.core.data.tensor.asTensor
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.ceil.ceil
import io.kinference.ndarray.extensions.ceil
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import io.kinference.core.data.tensor.asTensor
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.floor.floor
import io.kinference.ndarray.extensions.floor
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import io.kinference.core.data.tensor.asTensor
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.isNaN.isNaN
import io.kinference.ndarray.extensions.isNaN
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.operators.math

import io.kinference.KITestEngine
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SqrtTest {
private fun getTargetPath(dirName: String) = "sqrt/$dirName/"

@Test
fun test_sqrt() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sqrt"))
}

@Test
fun test_sqrt_example() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sqrt_example"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"Softmax" -> Softmax(name, version, attributes, inputs, outputs)
"Split" -> Split(name, version, attributes, inputs, outputs)
"SplitToSequence" -> SplitToSequence(name, version, attributes, inputs, outputs)
"Sqrt" -> Sqrt(name, version, attributes, inputs, outputs)
"Squeeze" -> Squeeze(name, version, attributes, inputs, outputs)
"Sub" -> Sub(name, version, attributes, inputs, outputs)
"Tan" -> Tan(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.kinference.tfjs.operators.math

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.NumberNDArrayTFJS
import io.kinference.ndarray.extensions.sqrt
import io.kinference.operator.*
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class Sqrt(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 6)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Sqrt {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in SqrtVer6.VERSION.asRange() -> SqrtVer6(name, attributes, inputs, outputs)
else -> error("Unsupported version of Sqrt operator: $version")
}
}
}
}

class SqrtVer6(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Sqrt(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val INPUTS_INFO = listOf(
IOInfo(0, TYPE_CONSTRAINTS, "X", optional = false)
)

private val OUTPUTS_INFO = listOf(
IOInfo(0, TYPE_CONSTRAINTS, "Y", optional = false)
)

internal val VERSION = VersionInfo(sinceVersion = 6)
private val INFO = OperatorInfo("Sqrt", emptyMap(), INPUTS_INFO, OUTPUTS_INFO, VERSION, domain = OperatorInfo.DEFAULT_DOMAIN)
}


override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> {
val input = inputs[0]!!.data as NumberNDArrayTFJS

return listOf(input.sqrt().asTensor("Y"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.tfjs.operators.math

import io.kinference.tfjs.runners.TFJSTestEngine
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SqrtTest {
private fun getTargetPath(dirName: String) = "sqrt/$dirName/"

@Test
fun test_sqrt() = TestRunner.runTest {
TFJSTestEngine.TFJSAccuracyRunner.runFromResources(getTargetPath("test_sqrt"))
}

@Test
fun test_sqrt_example() = TestRunner.runTest {
TFJSTestEngine.TFJSAccuracyRunner.runFromResources(getTargetPath("test_sqrt_example"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
@file:GeneratePrimitives(DataType.ALL)

package io.kinference.ndarray.extensions

import io.kinference.ndarray.arrays.*
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import io.kinference.primitives.types.PrimitiveType

fun PrimitiveNDArray.applyElementWise(func: (PrimitiveType) -> PrimitiveType): MutablePrimitiveNDArray {
val output = MutablePrimitiveNDArray(strides)

val inputBlockIter = array.blocks.iterator()
val outputBlockIter = output.array.blocks.iterator()
val blockSize = output.array.blockSize

repeat(output.array.blocksNum) {
val inputBlock = inputBlockIter.next()
val outputBlock = outputBlockIter.next()

for (idx in 0 until blockSize) {
outputBlock[idx] = func(inputBlock[idx])
}
}

return output
}

fun PrimitiveNDArray.predicateElementWise(predicate: (PrimitiveType) -> Boolean): BooleanNDArray {
val output = MutableBooleanNDArray(strides)

val inputBlockIter = array.blocks.iterator()
val outputBlockIter = output.array.blocks.iterator()
val blockSize = output.array.blockSize

repeat(output.array.blocksNum) {
val inputBlock = inputBlockIter.next()
val outputBlock = outputBlockIter.next()

for (idx in 0 until blockSize) {
outputBlock[idx] = predicate(inputBlock[idx])
}
}

return output
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
@file:GeneratePrimitives(
DataType.DOUBLE,
DataType.FLOAT
)

package io.kinference.ndarray.extensions

import io.kinference.ndarray.arrays.BooleanNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.*
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.*

fun PrimitiveNDArray.isNaN(): BooleanNDArray = predicateElementWise { it.isNaN() }

fun PrimitiveNDArray.ceil(): PrimitiveNDArray = applyElementWise { ceil(it) }
fun PrimitiveNDArray.floor(): PrimitiveNDArray = applyElementWise { floor(it) }

fun PrimitiveNDArray.sqrt(): PrimitiveNDArray = applyElementWise { sqrt(it) }
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,13 @@

package io.kinference.ndarray.extensions.abs

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.extensions.applyElementWise
import io.kinference.ndarray.stubs.abs
import io.kinference.primitives.annotations.GenerateNameFromPrimitives
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.abs

@GenerateNameFromPrimitives
internal fun absPrimitive(array: PrimitiveNDArray): PrimitiveNDArray {
val output = MutablePrimitiveNDArray(array.strides)

val inputBlockIter = array.array.blocks.iterator()
val outputBlockIter = output.array.blocks.iterator()
val blockSize = output.array.blockSize

repeat(output.array.blocksNum) {
val inputBlock = inputBlockIter.next()
val outputBlock = outputBlockIter.next()

for (idx in 0 until blockSize) {
outputBlock[idx] = abs(inputBlock[idx])
}
}

return output
}
internal fun absPrimitive(array: PrimitiveNDArray): PrimitiveNDArray = array.applyElementWise { abs(it) }
Loading

0 comments on commit 2860fa2

Please sign in to comment.