Skip to content

Commit

Permalink
JBAI-6945 [ndarray] Fixed broadcasting logic for batch processing.
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitriyb committed Sep 23, 2024
1 parent 61ff574 commit c917bb0
Showing 1 changed file with 14 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,15 @@ internal fun broadcastTwoTensorsPrimitive(
val batchSize = destBroadcastingShape[shapeIdx]

for (batchIdx in 0 until batchSize) {
val leftScalar = leftBlocks[leftOffset.value][0]
val leftBatchOffset = leftOffset.value + leftOffsets[shapeIdx] * batchIdx
val rightBatchOffset = rightOffset.value + rightOffsets[shapeIdx] * batchIdx
val destBatchOffset = destOffset.value + destOffsets[shapeIdx] * batchIdx

val leftScalar = leftBlocks[leftBatchOffset][0]

for (blockIdx in 0 until destBlocksInRow) {
val destBlock = destBlocks[destOffset.value + blockIdx]
val rightBlock = rightBlocks[rightOffset.value + blockIdx]
val destBlock = destBlocks[destBatchOffset + blockIdx]
val rightBlock = rightBlocks[rightBatchOffset + blockIdx]

for (idx in destBlock.indices) {
destBlock[idx] = op(leftScalar, rightBlock[idx])
Expand All @@ -68,11 +72,15 @@ internal fun broadcastTwoTensorsPrimitive(
val batchSize = destBroadcastingShape[shapeIdx]

for (batchIdx in 0 until batchSize) {
val rightScalar = rightBlocks[rightOffset.value][0]
val leftBatchOffset = leftOffset.value + leftOffsets[shapeIdx] * batchIdx
val rightBatchOffset = rightOffset.value + rightOffsets[shapeIdx] * batchIdx
val destBatchOffset = destOffset.value + destOffsets[shapeIdx] * batchIdx

val rightScalar = rightBlocks[rightBatchOffset][0]

for (blockIdx in 0 until destBlocksInRow) {
val destBlock = destBlocks[destOffset.value + blockIdx]
val leftBlock = leftBlocks[leftOffset.value + blockIdx]
val destBlock = destBlocks[destBatchOffset + blockIdx]
val leftBlock = leftBlocks[leftBatchOffset + blockIdx]

for (idx in destBlock.indices) {
destBlock[idx] = op(leftBlock[idx], rightScalar)
Expand Down

0 comments on commit c917bb0

Please sign in to comment.