Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor broadcasting + gelu and dot transposed improvements #205

Merged
merged 5 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ internal open class PrimitiveNDArray(array: PrimitiveTiledArray, strides: Stride
}

override fun view(vararg axes: Int): PrimitiveNDArray {
for ((i, axis) in axes.withIndex()) {
require(shape[i] > axis)
for (i in axes.indices) {
require(shape[i] > axes[i])
}

val offset = axes.foldIndexed(0) { index, acc, i -> acc + i * strides.strides[index] }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import io.kinference.primitives.types.DataType
// TODO remove to different module
fun unsqueezeFirst(shape: IntArray, newShapeSize: Int): IntArray {
val wrapSize = newShapeSize - shape.size
if (wrapSize == 0) return shape

val wrappedShape = IntArray(newShapeSize)
wrappedShape.fill(1, 0, wrapSize)
Expand Down Expand Up @@ -35,7 +36,11 @@ object Broadcasting {
destination: MutableNDArrayCore,
op: suspend (List<NDArrayCore>, MutableNDArrayCore) -> Unit
): MutableNDArrayCore {
val wrappedInputs = inputs.map { it.reshape(unsqueezeFirst(it.shape, destination.shape.size)) }
val destRank = destination.shape.size
val wrappedInputs = inputs.map { input ->
if (input.shape.size == destRank) input
else input.reshape(unsqueezeFirst(input.shape, destRank))
}

broadcast(wrappedInputs, destination, op)
return destination
Expand Down Expand Up @@ -101,29 +106,43 @@ object Broadcasting {
destination: MutableNDArrayCore,
recurrentBack: suspend (List<NDArrayCore>, MutableNDArrayCore) -> Unit
) {
val numInputs = inputs.size
val indexedInputs = inputs.withIndex()
val (arraysWithOne, arraysWithoutOne) = indexedInputs.partition { it.value.shape[0] == 1 }

val mergedInputs = MutableList<NDArrayCore>(numInputs) { inputs[0] }

if (destination.shape.size == 1) {
val broadcastSize = destination.shape.last()
val broadcastArraysWithOne = arraysWithOne.map {
val value = allocateNDArray(it.value.type, Strides(intArrayOf(broadcastSize)))
it.copy(value = value.apply { fill(it.value.singleValue()) })
val broadcastArraysWithOne = arraysWithOne.map { indexedInput ->
val value = allocateNDArray(indexedInput.value.type, Strides(intArrayOf(broadcastSize)))
value.apply { fill(indexedInput.value.singleValue()) }
}

arraysWithOne.forEachIndexed { i, indexedInput ->
mergedInputs[indexedInput.index] = broadcastArraysWithOne[i]
}

for (indexedInput in arraysWithoutOne) {
mergedInputs[indexedInput.index] = indexedInput.value
}
val mergedInputs = broadcastArraysWithOne.plus(arraysWithoutOne).sortedBy { it.index }.map { it.value }

return recurrentBack(mergedInputs, destination)
}

val viewedArraysWithOne = arraysWithOne.map { it.copy(value = it.value.view(0)) }
val fixedViewsWithOne = arraysWithOne.map { it.copy(value = it.value.view(0)) }

for (i in 0 until destination.shape[0]) {
val viewedArraysWithoutOne = arraysWithoutOne.map { it.copy(value = it.value.view(i)) }
val viewedDestination = destination.viewMutable(i)
for (indexedInput in fixedViewsWithOne) {
mergedInputs[indexedInput.index] = indexedInput.value
}

val mergedViewedInputs = viewedArraysWithOne.plus(viewedArraysWithoutOne).sortedBy { it.index }.map { it.value }
for (i in 0 until destination.shape[0]) {
for (indexedInput in arraysWithoutOne) {
mergedInputs[indexedInput.index] = indexedInput.value.view(i)
}

recurrentBack(mergedViewedInputs, viewedDestination)
val viewedDestination = destination.viewMutable(i)
recurrentBack(mergedInputs, viewedDestination)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,50 +134,33 @@ internal suspend fun PrimitiveNDArray.dotTransposedWithAlpha(alpha: Double, othe
val t = this.shape[1]
val m = other.shape[0]

val lrBlockSize = this.array.blockSize

val leftBlocks = this.array.blocks
val rightBlocks = other.array.blocks
val rowFlop = t * m


/* TODO: (dmitriyb) this is temporary commented. On GEC performance test we have large inputs that cause out of memory exceptions
We need to implement controlling mechanism which will prevent ArrayDispatcher of enormous grow*/

// This approach when arrays acquired before parallelizeByBlocks() is faster
// val coroutineCount = countCoroutinesByData(rowFlop, n, 262144)
// val containerArray = ArrayDispatcher.getArraysAndMarkers(PrimitiveTiledArray.type, lrBlockSize, m * coroutineCount)
// val mSumsArrays = Array(coroutineCount) { index ->
// Array(m) { mIndex ->
// (containerArray[index * m + mIndex] as PrimitiveArrayContainer).array
// }
// }

// Constant 262144 was precomputed on M1 Max processor
// With this constant two launches work faster than single thread without launches
// TODO: (cupertank) Remove constants
// TODO: (dmitriyb) Implement concurrent array retrieve with a separate structure from ArraysDispatcher
parallelizeByRows(rowFlop, n, 262144) { nStart: Int, nEnd: Int, _ ->
val tempSum = PrimitiveArray(lrBlockSize)
val destPointer = destination.array.pointer()
for (i in nStart until nEnd) {
val leftBlockOffset = i * lrBlocksInRow
val rightBlockIter = rightBlocks.iterator()
var rightBlockIndex = 0

destPointer.linearIndex = i * m

for (k in 0 until m) {
var totalSum = PrimitiveConstants.ZERO
for (lrBlock in 0 until lrBlocksInRow) {
val leftBlock = leftBlocks[leftBlockOffset + lrBlock]
val rightBlock = rightBlockIter.next()
val rightBlock = rightBlocks[rightBlockIndex++]

for (j in tempSum.indices) {
tempSum[j] += leftBlock[j] * rightBlock[j]
for (j in leftBlock.indices) {
totalSum += leftBlock[j] * rightBlock[j]
}
}

destPointer.setAndIncrement(tempSum.sum() * alpha)
tempSum.fill(PrimitiveConstants.ZERO)
destPointer.setAndIncrement((totalSum * alpha).toPrimitive())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ internal suspend fun dotParallelN(left: PrimitiveNDArray, right: PrimitiveNDArra
for (i in nStart until nEnd) {
val leftBlockOffset = i * lBlocksInRow
val destBlockOffset = i * rdBlocksInRow
val rightBlockIterator = rightBlocks.iterator()
var rightBlockIndex = 0

for (lCol in 0 until lBlocksInRow) {
val leftBlock = leftBlocks[leftBlockOffset + lCol]
Expand All @@ -41,7 +41,7 @@ internal suspend fun dotParallelN(left: PrimitiveNDArray, right: PrimitiveNDArra

for (rdCol in 0 until rdBlocksInRow) {
val destBlock = destBlocks[destBlockOffset + rdCol]
val rightBlock = rightBlockIterator.next()
val rightBlock = rightBlocks[rightBlockIndex++]

for (j in destBlock.indices) {
destBlock[j] = (destBlock[j] + temp * rightBlock[j]).toPrimitive()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ package io.kinference.ndarray.extensions.gelu

import io.kinference.ndarray.*
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.tiled.PrimitiveTiledArray
import io.kinference.ndarray.arrays.memory.contexts.AutoAllocatorContext
import io.kinference.ndarray.arrays.memory.storage.*
import io.kinference.ndarray.extensions.constants.PrimitiveConstants
import io.kinference.ndarray.stubs.absoluteValue
import io.kinference.ndarray.stubs.pow
import io.kinference.ndarray.math.*
import io.kinference.primitives.annotations.GenerateNameFromPrimitives
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.*
import kotlin.coroutines.coroutineContext
import kotlin.math.*

@GenerateNameFromPrimitives
Expand All @@ -22,12 +24,19 @@ internal suspend fun computeGeluPrimitive(input: PrimitiveNDArray, bias: Primiti

val blockSize = input.array.blockSize

val coroutineCount = countCoroutinesByData(blockSize, inputBlocks.size, 2048)
val temporaryBlocks = coroutineContext[AutoAllocatorContext]?.getPrimitiveBlock(coroutineCount, blockSize)
?: Array(coroutineCount) { PrimitiveArray(blockSize) }
val temporaryBlocksAbs = coroutineContext[AutoAllocatorContext]?.getPrimitiveBlock(coroutineCount, blockSize)
?: Array(coroutineCount) { PrimitiveArray(blockSize) }


// Constant 2048 was precomputed on M1 Max processor
// With this constant two launches work faster than single thread without launches
// TODO: (cupertank) Remove constants
parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, coroutineIndex ->
val temporaryBlock = PrimitiveArray(blockSize)
val temporaryBlockAbs = PrimitiveArray(blockSize)
val temporaryBlock = temporaryBlocks[coroutineIndex]
val temporaryBlockAbs = temporaryBlocksAbs[coroutineIndex]

for (blockIdx in blockStart until blockEnd) {
val outputBlock = outputBlocks[blockIdx]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ package io.kinference.ndarray.extensions.gelu
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.MutablePrimitiveNDArray
import io.kinference.ndarray.arrays.PrimitiveNDArray
import io.kinference.ndarray.arrays.memory.contexts.AutoAllocatorContext
import io.kinference.ndarray.arrays.memory.storage.*
import io.kinference.ndarray.arrays.tiled.PrimitiveTiledArray
import io.kinference.ndarray.countCoroutinesByData
import io.kinference.ndarray.parallelizeByBlocks
Expand All @@ -16,6 +18,7 @@ import io.kinference.ndarray.math.FastMath
import io.kinference.ndarray.math.exp
import io.kinference.primitives.annotations.GenerateNameFromPrimitives
import io.kinference.primitives.annotations.GeneratePrimitives
import kotlin.coroutines.coroutineContext
import kotlin.math.*

@GenerateNameFromPrimitives
Expand All @@ -27,11 +30,15 @@ internal suspend fun fastGeluPrimitive(input: PrimitiveNDArray, bias: PrimitiveN

val blockSize = input.array.blockSize

val coroutineCount = countCoroutinesByData(blockSize, inputBlocks.size, 2048)
val temporaryBlocksExp = coroutineContext[AutoAllocatorContext]?.getPrimitiveBlock(coroutineCount, blockSize)
?: Array(coroutineCount) { PrimitiveArray(blockSize) }

// Constant 2048 was precomputed on M1 Max processor
// With this constant two launches work faster than single thread without launches
// TODO: (cupertank) Remove constants
parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, _ ->
val temporaryBlockExp = PrimitiveArray(blockSize)
parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, coroutineIndex ->
val temporaryBlockExp = temporaryBlocksExp[coroutineIndex]
for (blockIdx in blockStart until blockEnd) {
val outputBlock = outputBlocks[blockIdx]
val block = inputBlocks[blockIdx]
Expand Down
Loading