Skip to content

Commit

Permalink
Merge branch 'master' into sum-op
Browse files Browse the repository at this point in the history
# Conflicts:
#	inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/KIOperatorFactory.kt
#	inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt
  • Loading branch information
cupertank committed Jul 17, 2023
2 parents f44a145 + a051454 commit e043eac
Show file tree
Hide file tree
Showing 68 changed files with 315 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ object KIOperatorFactory : OperatorFactory<KIONNXData<*>> {
"Squeeze" -> Squeeze(name, version, attributes, inputs, outputs)
"Sub" -> Sub(name, version, attributes, inputs, outputs)
"Sum" -> Sum(name, version, attributes, inputs, outputs)
"Tan" -> Tan(name, version, attributes, inputs, outputs)
"Tanh" -> Tanh(name, version, attributes, inputs, outputs)
"Tile" -> Tile(name, version, attributes, inputs, outputs)
"TopK" -> TopK(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.kinference.core.operators.activations

import io.kinference.attribute.Attribute
import io.kinference.core.KIONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.activations.tan.tan
import io.kinference.operator.*
import io.kinference.primitives.types.DataType

sealed class Tan(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>, outputs: List<String>
) : Activation(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 7)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Tan {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in TanVer7.VERSION.asRange() -> TanVer7(name, attributes, inputs, outputs)
else -> error("Unsupported version of Tan operator: $version")
}
}
}
}

class TanVer7(
name: String,
attributes: Map<String, Attribute<Any>> = emptyMap(),
inputs: List<String>,
outputs: List<String>
) : Tan(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val INPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input", optional = false))
private val OUTPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "output", optional = false))

internal val VERSION = VersionInfo(sinceVersion = 7)
private val INFO = OperatorInfo("Tan", emptySet(), INPUT_INFO, OUTPUT_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun activate(input: NDArrayCore, contexts: Contexts<KIONNXData<*>>): NDArrayCore {
return when (val type = input.type) {
DataType.FLOAT -> (input as FloatNDArray).tan()
DataType.DOUBLE -> (input as DoubleNDArray).tan()
else -> error("Unsupported data type : $type")
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.operators.activations

import io.kinference.KITestEngine
import io.kinference.utils.TestRunner
import kotlin.test.Test

class TanTest {
private fun getTargetPath(dirName: String) = "tan/$dirName/"

@Test
fun test_tanh_example() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_tan_example"))
}

@Test
fun test_tanh() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_tan"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"Squeeze" -> Squeeze(name, version, attributes, inputs, outputs)
"Sub" -> Sub(name, version, attributes, inputs, outputs)
"Sum" -> Sum(name, version, attributes, inputs, outputs)
"Tan" -> Tan(name, version, attributes, inputs, outputs)
"Tanh" -> Tanh(name, version, attributes, inputs, outputs)
"Tile" -> Tile(name, version, attributes, inputs, outputs)
"Transpose" -> Transpose(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.kinference.tfjs.operators.activations

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.tan
import io.kinference.operator.*
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class Tan(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 7)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Tan {
return when (version ?: DEFAULT_VERSION.sinceVersion) {
in TanVer7.VERSION.asRange() -> TanVer7(name, attributes, inputs, outputs)
else -> error("Unsupported version of Tan operator: $version")
}
}
}
}


class TanVer7(
name: String,
attributes: Map<String, Attribute<Any>> = emptyMap(),
inputs: List<String>,
outputs: List<String>
) : Tan(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = FLOAT_DATA_TYPES

private val INPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "input", optional = false))
private val OUTPUT_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "output", optional = false))

internal val VERSION = VersionInfo(sinceVersion = 7)
private val INFO = OperatorInfo("Tan", emptySet(), INPUT_INFO, OUTPUT_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> {
val input = inputs[0]!!.data as NumberNDArrayTFJS
return listOf(input.tan().asTensor("output"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.kinference.tfjs.operators.activations

import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner
import io.kinference.utils.TestRunner
import kotlin.test.Test

class TanTest {
private fun getTargetPath(dirName: String) = "tan/$dirName/"

@Test
fun test_tanh_example() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tan_example"))
}

@Test
fun test_tanh() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tan"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import io.kinference.ndarray.extensions.argMinMax.ArgMinMaxMode
import io.kinference.ndarray.extensions.argMinMax.argMinMaxPrimitive
import io.kinference.ndarray.extensions.dot.*
import io.kinference.ndarray.extensions.softmax.softmax
import io.kinference.ndarray.stubs.isCompatibleWith
import io.kinference.primitives.annotations.GenerateNameFromPrimitives
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.*
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import io.kinference.ndarray.*
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.pointers.accept
import io.kinference.ndarray.arrays.pointers.acceptWithRecursive
import io.kinference.ndarray.stubs.*
import io.kinference.ndarray.arrays.tiled.*
import io.kinference.primitives.annotations.*
import io.kinference.primitives.types.*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package io.kinference.ndarray.extensions

import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.stubs.pow
import io.kinference.ndarray.extensions.utils.calculateInnerShapeSize
import io.kinference.ndarray.extensions.utils.divCeil
import io.kinference.primitives.annotations.GeneratePrimitives
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package io.kinference.ndarray.extensions.abs

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.abs
import io.kinference.primitives.annotations.GenerateNameFromPrimitives
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package io.kinference.ndarray.extensions.abs

import io.kinference.primitives.types.PrimitiveType
import kotlin.math.abs

internal fun abs(x: Short) = abs(x.toInt()).toShort()
internal fun abs(x: Byte) = abs(x.toInt()).toByte()

internal inline fun abs(x: PrimitiveType): PrimitiveType = throw UnsupportedOperationException()
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package io.kinference.ndarray.extensions.activations.acos

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.acos
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.acos
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package io.kinference.ndarray.extensions.activations.acosh

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.acosh
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.acosh
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package io.kinference.ndarray.extensions.activations.asin

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.asin
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.asin
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package io.kinference.ndarray.extensions.activations.asinh

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.asinh
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.asinh
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package io.kinference.ndarray.extensions.activations.atan

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.atan
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.atan
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package io.kinference.ndarray.extensions.activations.atanh

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.atanh
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.atanh
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package io.kinference.ndarray.extensions.activations.cos

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.cos
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.cos
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package io.kinference.ndarray.extensions.activations.cosh

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.cosh
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.cosh
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
@file:GeneratePrimitives(
DataType.FLOAT,
DataType.DOUBLE
)

package io.kinference.ndarray.extensions.activations.tan

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.tan
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.math.tan

fun PrimitiveNDArray.tan(): PrimitiveNDArray {
val output = MutablePrimitiveNDArray(this.strides)

val outputIter = output.array.blocks.iterator()
val inputIter = this.array.blocks.iterator()
val blocksNum = this.array.blocksNum

repeat(blocksNum) {
val inputBlock = inputIter.next()
val outputBlock = outputIter.next()

for (idx in outputBlock.indices) {
outputBlock[idx] = tan(inputBlock[idx])
}
}

return output
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package io.kinference.ndarray.extensions.bitwise.and

import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.broadcasting.broadcastTwoTensorsPrimitive
import io.kinference.ndarray.stubs.and
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import io.kinference.primitives.types.PrimitiveType
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package io.kinference.ndarray.extensions.bitwise.not

import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.stubs.inv
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import kotlin.experimental.inv
Expand Down
Loading

0 comments on commit e043eac

Please sign in to comment.