Skip to content

Commit

Permalink
KI-48 [core, tfjs] Add SequenceErase tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AnastasiaTuchina committed Jul 25, 2023
1 parent c65cb69 commit 79197c5
Show file tree
Hide file tree
Showing 18 changed files with 73 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ class KIONNXSequence(name: String?, data: List<KIONNXData<*>>, val info: ValueTy

companion object {
fun create(proto: SequenceProto): KIONNXSequence {
val elementTypeInfo = proto.extractTypeInfo() as ValueTypeInfo.SequenceTypeInfo
val elementTypeInfo = proto.extractTypeInfo()
val name = proto.name!!
val data = when (proto.elementType) {
SequenceProto.DataType.TENSOR -> proto.tensorValues.map { KITensor.create(it) }
SequenceProto.DataType.SEQUENCE -> proto.sequenceValues.map { create(it) }
SequenceProto.DataType.MAP -> proto.mapValues.map { KIONNXMap.create(it) }
else -> error("Unsupported sequence element type: ${proto.elementType}")
}
return KIONNXSequence(name, data, elementTypeInfo)
return KIONNXSequence(name, data, ValueTypeInfo.SequenceTypeInfo(elementTypeInfo))
}

internal fun SequenceProto.extractTypeInfo(): ValueTypeInfo = when (this.elementType) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.kinference.operators.seq

import io.kinference.KITestEngine
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SequenceEraseTest {
private fun getTargetPath(dirName: String) = "sequence_erase/$dirName/"

@Test
fun test_sequence_erase_default() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sequence_erase_default"))
}

@Test
fun test_sequence_erase_positive() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sequence_erase_positive"))
}

@Test
fun test_sequence_erase_negative() = TestRunner.runTest {
KITestEngine.KIAccuracyRunner.runFromResources(getTargetPath("test_sequence_erase_negative"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ class TFJSSequence(name: String?, data: List<TFJSData<*>>, val info: ValueTypeIn

companion object {
fun create(proto: SequenceProto): TFJSSequence {
val elementTypeInfo = proto.extractTypeInfo() as ValueTypeInfo.SequenceTypeInfo
val elementTypeInfo = proto.extractTypeInfo()
val name = proto.name!!
val data = when (proto.elementType) {
SequenceProto.DataType.TENSOR -> proto.tensorValues.map { TFJSTensor.create(it) }
SequenceProto.DataType.SEQUENCE -> proto.sequenceValues.map { create(it) }
SequenceProto.DataType.MAP -> proto.mapValues.map { TFJSMap.create(it) }
else -> error("Unsupported sequence element type: ${proto.elementType}")
}
return TFJSSequence(name, data, elementTypeInfo)
return TFJSSequence(name, data, ValueTypeInfo.SequenceTypeInfo(elementTypeInfo))
}

internal fun SequenceProto.extractTypeInfo(): ValueTypeInfo = when (this.elementType) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.kinference.tfjs.operators.seq

import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner
import io.kinference.utils.TestRunner
import kotlin.test.Test

class SequenceEraseTest {
private fun getTargetPath(dirName: String) = "sequence_erase/$dirName/"

@Test
fun test_sequence_erase_default() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_sequence_erase_default"))
}

@Test
fun test_sequence_erase_positive() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_sequence_erase_positive"))
}

@Test
fun test_sequence_erase_negative() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_sequence_erase_negative"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
test_data_set_0/input_0.pb:ONNX_TYPE:ONNX_SEQUENCE
test_data_set_0/output_0.pb:ONNX_TYPE:ONNX_SEQUENCE
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
:
0
input_sequenceoutput_sequence"SequenceErasetest_SequenceEraseZ
input_sequence"

b
output_sequence"

B
ai.onnx
Expand Down
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
test_data_set_0/input_0.pb:ONNX_TYPE:ONNX_SEQUENCE
test_data_set_0/input_1.pb
test_data_set_0/output_0.pb:ONNX_TYPE:ONNX_SEQUENCE
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
���������Bposition
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
test_data_set_0/input_0.pb:ONNX_TYPE:ONNX_SEQUENCE
test_data_set_0/input_1.pb
test_data_set_0/output_0.pb:ONNX_TYPE:ONNX_SEQUENCE
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*Bposition
Binary file not shown.

0 comments on commit 79197c5

Please sign in to comment.