From f5c2d29ae6db693cefdb8e5f2a45c8f2637352e1 Mon Sep 17 00:00:00 2001 From: dmitriyb Date: Tue, 15 Oct 2024 10:50:39 +0200 Subject: [PATCH 1/5] JBAI-197 [ndarray] Refactored dotTransposedWithAlpha: eliminated temp allocations and iterator creations. --- .../ndarray/extensions/PrimitiveExtensions.kt | 29 ++++--------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/PrimitiveExtensions.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/PrimitiveExtensions.kt index cbf651bc..0b367b7b 100644 --- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/PrimitiveExtensions.kt +++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/PrimitiveExtensions.kt @@ -134,50 +134,33 @@ internal suspend fun PrimitiveNDArray.dotTransposedWithAlpha(alpha: Double, othe val t = this.shape[1] val m = other.shape[0] - val lrBlockSize = this.array.blockSize - val leftBlocks = this.array.blocks val rightBlocks = other.array.blocks val rowFlop = t * m - - /* TODO: (dmitriyb) this is temporary commented. On GEC performance test we have large inputs that cause out of memory exceptions - We need to implement controlling mechanism which will prevent ArrayDispatcher of enormous grow*/ - - // This approach when arrays acquired before parallelizeByBlocks() is faster -// val coroutineCount = countCoroutinesByData(rowFlop, n, 262144) -// val containerArray = ArrayDispatcher.getArraysAndMarkers(PrimitiveTiledArray.type, lrBlockSize, m * coroutineCount) -// val mSumsArrays = Array(coroutineCount) { index -> -// Array(m) { mIndex -> -// (containerArray[index * m + mIndex] as PrimitiveArrayContainer).array -// } -// } - // Constant 262144 was precomputed on M1 Max processor // With this constant two launches work faster than single thread without launches // TODO: (cupertank) Remove constants - // TODO: (dmitriyb) Implement concurrent array retrieve with a separate structure from ArraysDispatcher parallelizeByRows(rowFlop, n, 262144) { nStart: Int, nEnd: Int, _ -> - val tempSum = PrimitiveArray(lrBlockSize) val destPointer = destination.array.pointer() for (i in nStart until nEnd) { val leftBlockOffset = i * lrBlocksInRow - val rightBlockIter = rightBlocks.iterator() + var rightBlockIndex = 0 destPointer.linearIndex = i * m for (k in 0 until m) { + var totalSum = PrimitiveConstants.ZERO for (lrBlock in 0 until lrBlocksInRow) { val leftBlock = leftBlocks[leftBlockOffset + lrBlock] - val rightBlock = rightBlockIter.next() + val rightBlock = rightBlocks[rightBlockIndex++] - for (j in tempSum.indices) { - tempSum[j] += leftBlock[j] * rightBlock[j] + for (j in leftBlock.indices) { + totalSum += leftBlock[j] * rightBlock[j] } } - destPointer.setAndIncrement(tempSum.sum() * alpha) - tempSum.fill(PrimitiveConstants.ZERO) + destPointer.setAndIncrement((totalSum * alpha).toPrimitive()) } } } From d1efc5febf4506ac7c34669e6c875b7b21b08dd3 Mon Sep 17 00:00:00 2001 From: dmitriyb Date: Tue, 15 Oct 2024 10:59:25 +0200 Subject: [PATCH 2/5] JBAI-197 [ndarray] Refactored axis and block index checks in NDArray view and elimination of an iterator object in dotParallelN. --- .../kotlin/io/kinference/ndarray/arrays/PrimitiveNDArray.kt | 4 ++-- .../ndarray/extensions/dot/PrimitiveDotParallelN.kt | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/PrimitiveNDArray.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/PrimitiveNDArray.kt index f1bd91b4..59896ff8 100644 --- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/PrimitiveNDArray.kt +++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/PrimitiveNDArray.kt @@ -42,8 +42,8 @@ internal open class PrimitiveNDArray(array: PrimitiveTiledArray, strides: Stride } override fun view(vararg axes: Int): PrimitiveNDArray { - for ((i, axis) in axes.withIndex()) { - require(shape[i] > axis) + for (i in axes.indices) { + require(shape[i] > axes[i]) } val offset = axes.foldIndexed(0) { index, acc, i -> acc + i * strides.strides[index] } diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/dot/PrimitiveDotParallelN.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/dot/PrimitiveDotParallelN.kt index 4f564549..8b76bdab 100644 --- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/dot/PrimitiveDotParallelN.kt +++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/dot/PrimitiveDotParallelN.kt @@ -31,7 +31,7 @@ internal suspend fun dotParallelN(left: PrimitiveNDArray, right: PrimitiveNDArra for (i in nStart until nEnd) { val leftBlockOffset = i * lBlocksInRow val destBlockOffset = i * rdBlocksInRow - val rightBlockIterator = rightBlocks.iterator() + var rightBlockIndex = 0 for (lCol in 0 until lBlocksInRow) { val leftBlock = leftBlocks[leftBlockOffset + lCol] @@ -41,7 +41,7 @@ internal suspend fun dotParallelN(left: PrimitiveNDArray, right: PrimitiveNDArra for (rdCol in 0 until rdBlocksInRow) { val destBlock = destBlocks[destBlockOffset + rdCol] - val rightBlock = rightBlockIterator.next() + val rightBlock = rightBlocks[rightBlockIndex++] for (j in destBlock.indices) { destBlock[j] = (destBlock[j] + temp * rightBlock[j]).toPrimitive() From d3bf6dbaaac32674cac714d65c12dab3929a81f9 Mon Sep 17 00:00:00 2001 From: dmitriyb Date: Tue, 15 Oct 2024 11:00:21 +0200 Subject: [PATCH 3/5] JBAI-197 [ndarray] Optimized GELU operation by reusing temporary blocks. --- .../ndarray/extensions/gelu/BiasGeluPrimitive.kt | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/gelu/BiasGeluPrimitive.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/gelu/BiasGeluPrimitive.kt index 9ba08ddb..c3eaafb3 100644 --- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/gelu/BiasGeluPrimitive.kt +++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/gelu/BiasGeluPrimitive.kt @@ -3,7 +3,8 @@ package io.kinference.ndarray.extensions.gelu import io.kinference.ndarray.* import io.kinference.ndarray.arrays.* -import io.kinference.ndarray.arrays.tiled.PrimitiveTiledArray +import io.kinference.ndarray.arrays.memory.contexts.AutoAllocatorContext +import io.kinference.ndarray.arrays.memory.storage.* import io.kinference.ndarray.extensions.constants.PrimitiveConstants import io.kinference.ndarray.stubs.absoluteValue import io.kinference.ndarray.stubs.pow @@ -11,6 +12,7 @@ import io.kinference.ndarray.math.* import io.kinference.primitives.annotations.GenerateNameFromPrimitives import io.kinference.primitives.annotations.GeneratePrimitives import io.kinference.primitives.types.* +import kotlin.coroutines.coroutineContext import kotlin.math.* @GenerateNameFromPrimitives @@ -22,12 +24,19 @@ internal suspend fun computeGeluPrimitive(input: PrimitiveNDArray, bias: Primiti val blockSize = input.array.blockSize + val coroutineCount = countCoroutinesByData(blockSize, inputBlocks.size, 2048) + val temporaryBlocks = coroutineContext[AutoAllocatorContext]?.getPrimitiveBlock(coroutineCount, blockSize) + ?: Array(coroutineCount) { PrimitiveArray(blockSize) } + val temporaryBlocksAbs = coroutineContext[AutoAllocatorContext]?.getPrimitiveBlock(coroutineCount, blockSize) + ?: Array(coroutineCount) { PrimitiveArray(blockSize) } + + // Constant 2048 was precomputed on M1 Max processor // With this constant two launches work faster than single thread without launches // TODO: (cupertank) Remove constants parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, coroutineIndex -> - val temporaryBlock = PrimitiveArray(blockSize) - val temporaryBlockAbs = PrimitiveArray(blockSize) + val temporaryBlock = temporaryBlocks[coroutineIndex] + val temporaryBlockAbs = temporaryBlocksAbs[coroutineIndex] for (blockIdx in blockStart until blockEnd) { val outputBlock = outputBlocks[blockIdx] From cca1f0f578ec6cb8a9eb0c395c8e1d77695b9c60 Mon Sep 17 00:00:00 2001 From: dmitriyb Date: Tue, 15 Oct 2024 13:44:45 +0200 Subject: [PATCH 4/5] JBAI-197 [ndarray] Optimized FastGelu operation by reusing temporary blocks. --- .../ndarray/extensions/gelu/FastGeluPrimitive.kt | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/gelu/FastGeluPrimitive.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/gelu/FastGeluPrimitive.kt index 32a1e6e0..f6a1603a 100644 --- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/gelu/FastGeluPrimitive.kt +++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/extensions/gelu/FastGeluPrimitive.kt @@ -6,6 +6,8 @@ package io.kinference.ndarray.extensions.gelu import io.kinference.ndarray.arrays.* import io.kinference.ndarray.arrays.MutablePrimitiveNDArray import io.kinference.ndarray.arrays.PrimitiveNDArray +import io.kinference.ndarray.arrays.memory.contexts.AutoAllocatorContext +import io.kinference.ndarray.arrays.memory.storage.* import io.kinference.ndarray.arrays.tiled.PrimitiveTiledArray import io.kinference.ndarray.countCoroutinesByData import io.kinference.ndarray.parallelizeByBlocks @@ -16,6 +18,7 @@ import io.kinference.ndarray.math.FastMath import io.kinference.ndarray.math.exp import io.kinference.primitives.annotations.GenerateNameFromPrimitives import io.kinference.primitives.annotations.GeneratePrimitives +import kotlin.coroutines.coroutineContext import kotlin.math.* @GenerateNameFromPrimitives @@ -27,11 +30,15 @@ internal suspend fun fastGeluPrimitive(input: PrimitiveNDArray, bias: PrimitiveN val blockSize = input.array.blockSize + val coroutineCount = countCoroutinesByData(blockSize, inputBlocks.size, 2048) + val temporaryBlocksExp = coroutineContext[AutoAllocatorContext]?.getPrimitiveBlock(coroutineCount, blockSize) + ?: Array(coroutineCount) { PrimitiveArray(blockSize) } + // Constant 2048 was precomputed on M1 Max processor // With this constant two launches work faster than single thread without launches // TODO: (cupertank) Remove constants - parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, _ -> - val temporaryBlockExp = PrimitiveArray(blockSize) + parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, coroutineIndex -> + val temporaryBlockExp = temporaryBlocksExp[coroutineIndex] for (blockIdx in blockStart until blockEnd) { val outputBlock = outputBlocks[blockIdx] val block = inputBlocks[blockIdx] From 80e31066809ebc5366bd775b806bdac30b374ef7 Mon Sep 17 00:00:00 2001 From: dmitriyb Date: Tue, 15 Oct 2024 15:22:47 +0200 Subject: [PATCH 5/5] JBAI-197 [ndarray] Optimized broadcasting functions by avoiding some allocations. --- .../ndarray/broadcasting/Broadcasting.kt | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/broadcasting/Broadcasting.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/broadcasting/Broadcasting.kt index 7e2e1855..a86eb3f4 100644 --- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/broadcasting/Broadcasting.kt +++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/broadcasting/Broadcasting.kt @@ -7,6 +7,7 @@ import io.kinference.primitives.types.DataType // TODO remove to different module fun unsqueezeFirst(shape: IntArray, newShapeSize: Int): IntArray { val wrapSize = newShapeSize - shape.size + if (wrapSize == 0) return shape val wrappedShape = IntArray(newShapeSize) wrappedShape.fill(1, 0, wrapSize) @@ -35,7 +36,11 @@ object Broadcasting { destination: MutableNDArrayCore, op: suspend (List, MutableNDArrayCore) -> Unit ): MutableNDArrayCore { - val wrappedInputs = inputs.map { it.reshape(unsqueezeFirst(it.shape, destination.shape.size)) } + val destRank = destination.shape.size + val wrappedInputs = inputs.map { input -> + if (input.shape.size == destRank) input + else input.reshape(unsqueezeFirst(input.shape, destRank)) + } broadcast(wrappedInputs, destination, op) return destination @@ -101,29 +106,43 @@ object Broadcasting { destination: MutableNDArrayCore, recurrentBack: suspend (List, MutableNDArrayCore) -> Unit ) { + val numInputs = inputs.size val indexedInputs = inputs.withIndex() val (arraysWithOne, arraysWithoutOne) = indexedInputs.partition { it.value.shape[0] == 1 } + val mergedInputs = MutableList(numInputs) { inputs[0] } + if (destination.shape.size == 1) { val broadcastSize = destination.shape.last() - val broadcastArraysWithOne = arraysWithOne.map { - val value = allocateNDArray(it.value.type, Strides(intArrayOf(broadcastSize))) - it.copy(value = value.apply { fill(it.value.singleValue()) }) + val broadcastArraysWithOne = arraysWithOne.map { indexedInput -> + val value = allocateNDArray(indexedInput.value.type, Strides(intArrayOf(broadcastSize))) + value.apply { fill(indexedInput.value.singleValue()) } + } + + arraysWithOne.forEachIndexed { i, indexedInput -> + mergedInputs[indexedInput.index] = broadcastArraysWithOne[i] + } + + for (indexedInput in arraysWithoutOne) { + mergedInputs[indexedInput.index] = indexedInput.value } - val mergedInputs = broadcastArraysWithOne.plus(arraysWithoutOne).sortedBy { it.index }.map { it.value } return recurrentBack(mergedInputs, destination) } - val viewedArraysWithOne = arraysWithOne.map { it.copy(value = it.value.view(0)) } + val fixedViewsWithOne = arraysWithOne.map { it.copy(value = it.value.view(0)) } - for (i in 0 until destination.shape[0]) { - val viewedArraysWithoutOne = arraysWithoutOne.map { it.copy(value = it.value.view(i)) } - val viewedDestination = destination.viewMutable(i) + for (indexedInput in fixedViewsWithOne) { + mergedInputs[indexedInput.index] = indexedInput.value + } - val mergedViewedInputs = viewedArraysWithOne.plus(viewedArraysWithoutOne).sortedBy { it.index }.map { it.value } + for (i in 0 until destination.shape[0]) { + for (indexedInput in arraysWithoutOne) { + mergedInputs[indexedInput.index] = indexedInput.value.view(i) + } - recurrentBack(mergedViewedInputs, viewedDestination) + val viewedDestination = destination.viewMutable(i) + recurrentBack(mergedInputs, viewedDestination) } } }