diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/broadcasting/BroadcastTwoArgumentsPrimitive.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/broadcasting/BroadcastTwoArgumentsPrimitive.kt index 90056a8bf..61fc1c076 100644 --- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/broadcasting/BroadcastTwoArgumentsPrimitive.kt +++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/broadcasting/BroadcastTwoArgumentsPrimitive.kt @@ -50,11 +50,15 @@ internal fun broadcastTwoTensorsPrimitive( val batchSize = destBroadcastingShape[shapeIdx] for (batchIdx in 0 until batchSize) { - val leftScalar = leftBlocks[leftOffset.value][0] + val leftBatchOffset = leftOffset.value + leftOffsets[shapeIdx] * batchIdx + val rightBatchOffset = rightOffset.value + rightOffsets[shapeIdx] * batchIdx + val destBatchOffset = destOffset.value + destOffsets[shapeIdx] * batchIdx + + val leftScalar = leftBlocks[leftBatchOffset][0] for (blockIdx in 0 until destBlocksInRow) { - val destBlock = destBlocks[destOffset.value + blockIdx] - val rightBlock = rightBlocks[rightOffset.value + blockIdx] + val destBlock = destBlocks[destBatchOffset + blockIdx] + val rightBlock = rightBlocks[rightBatchOffset + blockIdx] for (idx in destBlock.indices) { destBlock[idx] = op(leftScalar, rightBlock[idx]) @@ -68,11 +72,15 @@ internal fun broadcastTwoTensorsPrimitive( val batchSize = destBroadcastingShape[shapeIdx] for (batchIdx in 0 until batchSize) { - val rightScalar = rightBlocks[rightOffset.value][0] + val leftBatchOffset = leftOffset.value + leftOffsets[shapeIdx] * batchIdx + val rightBatchOffset = rightOffset.value + rightOffsets[shapeIdx] * batchIdx + val destBatchOffset = destOffset.value + destOffsets[shapeIdx] * batchIdx + + val rightScalar = rightBlocks[rightBatchOffset][0] for (blockIdx in 0 until destBlocksInRow) { - val destBlock = destBlocks[destOffset.value + blockIdx] - val leftBlock = leftBlocks[leftOffset.value + blockIdx] + val destBlock = destBlocks[destBatchOffset + blockIdx] + val leftBlock = leftBlocks[leftBatchOffset + blockIdx] for (idx in destBlock.indices) { destBlock[idx] = op(leftBlock[idx], rightScalar)