diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt index 9d8c7ff71..aec6f276f 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/cumSum.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.math.cumSum import org.jetbrains.kotlinx.dataframe.math.defaultCumSumSkipNA import org.jetbrains.kotlinx.dataframe.typeClass import java.math.BigDecimal +import java.math.BigInteger import kotlin.reflect.KProperty import kotlin.reflect.typeOf @@ -22,15 +23,30 @@ public fun DataColumn.cumSum(skipNA: Boolean = defaultCumSumSki typeOf() -> cast().cumSum(skipNA).cast() typeOf() -> cast().cumSum().cast() typeOf() -> cast().cumSum(skipNA).cast() + typeOf() -> cast().cumSum().cast() + typeOf() -> cast().cumSum(skipNA).cast() + typeOf() -> cast().cumSum().cast() + typeOf() -> cast().cumSum(skipNA).cast() typeOf() -> cast().cumSum().cast() typeOf() -> cast().cumSum(skipNA).cast() + typeOf() -> cast().cumSum().cast() + typeOf() -> cast().cumSum(skipNA).cast() typeOf() -> cast().cumSum().cast() typeOf() -> cast().cumSum(skipNA).cast() typeOf(), typeOf() -> convertToDouble().cumSum(skipNA).cast() else -> error("Cumsum for type ${type()} is not supported") } -private val supportedClasses = setOf(Double::class, Float::class, Int::class, Long::class, BigDecimal::class) +private val supportedClasses = setOf( + Double::class, + Float::class, + Int::class, + Byte::class, + Short::class, + Long::class, + BigInteger::class, + BigDecimal::class, +) // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt index 58ef59329..5efb7ff97 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/cumsum.kt @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.api.isNA import org.jetbrains.kotlinx.dataframe.api.map import java.math.BigDecimal +import java.math.BigInteger internal val defaultCumSumSkipNA: Boolean = true @@ -88,6 +89,66 @@ internal fun DataColumn.cumSum(skipNA: Boolean = defaultCumSumSkipNA): Dat } } +@JvmName("byteCumsum") +internal fun DataColumn.cumSum(): DataColumn { + var sum = 0.toByte() + return map { + sum = (sum + it).toByte() + sum + } +} + +@JvmName("cumsumByteNullable") +internal fun DataColumn.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn { + var sum = 0.toByte() + var fillNull = false + return map { + when { + it == null -> { + if (!skipNA) fillNull = true + null + } + + fillNull -> null + + else -> { + sum = (sum + it).toByte() + sum + } + } + } +} + +@JvmName("shortCumsum") +internal fun DataColumn.cumSum(): DataColumn { + var sum = 0.toShort() + return map { + sum = (sum + it).toShort() + sum + } +} + +@JvmName("cumsumShortNullable") +internal fun DataColumn.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn { + var sum = 0.toShort() + var fillNull = false + return map { + when { + it == null -> { + if (!skipNA) fillNull = true + null + } + + fillNull -> null + + else -> { + sum = (sum + it).toShort() + sum + } + } + } +} + @JvmName("longCumsum") internal fun DataColumn.cumSum(): DataColumn { var sum = 0L @@ -118,6 +179,36 @@ internal fun DataColumn.cumSum(skipNA: Boolean = defaultCumSumSkipNA): Da } } +@JvmName("bigIntegerCumsum") +internal fun DataColumn.cumSum(): DataColumn { + var sum = BigInteger.ZERO + return map { + sum += it + sum + } +} + +@JvmName("cumsumBigIntegerNullable") +internal fun DataColumn.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn { + var sum = BigInteger.ZERO + var fillNull = false + return map { + when { + it == null -> { + if (!skipNA) fillNull = true + null + } + + fillNull -> null + + else -> { + sum += it + sum + } + } + } +} + @JvmName("bigDecimalCumsum") internal fun DataColumn.cumSum(): DataColumn { var sum = BigDecimal.ZERO diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt index c7c6e9596..4ce9e68f9 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt @@ -1,15 +1,30 @@ package org.jetbrains.kotlinx.dataframe.math +import org.jetbrains.kotlinx.dataframe.api.mean import org.jetbrains.kotlinx.dataframe.api.skipNA_default import org.jetbrains.kotlinx.dataframe.impl.renderType +import org.jetbrains.kotlinx.dataframe.util.INTERNAL_MEAN +import org.jetbrains.kotlinx.dataframe.util.MEAN +import org.jetbrains.kotlinx.dataframe.util.SEQUENCE_FLOAT_MEAN import java.math.BigDecimal +import java.math.BigInteger import kotlin.reflect.KType import kotlin.reflect.full.withNullability +import kotlin.reflect.typeOf + +@JvmName("meanIterableReified") +@PublishedApi +internal inline fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = + mean(typeOf(), skipNA) @PublishedApi internal fun Iterable.mean(type: KType, skipNA: Boolean = skipNA_default): Double = asSequence().mean(type, skipNA) +@JvmName("meanSequenceReified") +internal inline fun Sequence.mean(skipNA: Boolean = skipNA_default): Double = + mean(typeOf(), skipNA) + internal fun Sequence.mean(type: KType, skipNA: Boolean = skipNA_default): Double { if (type.isMarkedNullable) { return filterNotNull().mean(type.withNullability(false), skipNA) @@ -17,7 +32,7 @@ internal fun Sequence.mean(type: KType, skipNA: Boolean = skipNA return when (type.classifier) { Double::class -> (this as Sequence).mean(skipNA) - Float::class -> (this as Sequence).mean(skipNA) + Float::class -> (this as Sequence).map { it.toDouble() }.mean(skipNA) Int::class -> (this as Sequence).map { it.toDouble() }.mean(false) @@ -28,6 +43,8 @@ internal fun Sequence.mean(type: KType, skipNA: Boolean = skipNA Long::class -> (this as Sequence).map { it.toDouble() }.mean(false) + BigInteger::class -> (this as Sequence).map { it.toDouble() }.mean(false) + BigDecimal::class -> (this as Sequence).map { it.toDouble() }.mean(skipNA) Number::class -> (this as Sequence).map { it.toDouble() }.mean(skipNA) @@ -39,7 +56,7 @@ internal fun Sequence.mean(type: KType, skipNA: Boolean = skipNA } } -public fun Sequence.mean(skipNA: Boolean = skipNA_default): Double { +private fun Sequence.mean(skipNA: Boolean = skipNA_default): Double { var count = 0 var sum: Double = 0.toDouble() for (element in this) { @@ -56,8 +73,9 @@ public fun Sequence.mean(skipNA: Boolean = skipNA_default): Double { return if (count > 0) sum / count else Double.NaN } +@Deprecated(SEQUENCE_FLOAT_MEAN, level = DeprecationLevel.ERROR) @JvmName("meanFloat") -public fun Sequence.mean(skipNA: Boolean = skipNA_default): Double { +internal fun Sequence.mean(skipNA: Boolean = skipNA_default): Double { var count = 0 var sum: Double = 0.toDouble() for (element in this) { @@ -74,12 +92,15 @@ public fun Sequence.mean(skipNA: Boolean = skipNA_default): Double { return if (count > 0) sum / count else Double.NaN } +@Deprecated(INTERNAL_MEAN, level = DeprecationLevel.HIDDEN) @JvmName("doubleMean") -public fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = asSequence().mean(skipNA) +internal fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = mean(typeOf(), skipNA) +@Deprecated(INTERNAL_MEAN, level = DeprecationLevel.HIDDEN) @JvmName("floatMean") -public fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = asSequence().mean(skipNA) +internal fun Iterable.mean(skipNA: Boolean = skipNA_default): Double = mean(typeOf(), skipNA) +@Deprecated(MEAN, level = DeprecationLevel.HIDDEN) @JvmName("intMean") public fun Iterable.mean(): Double = if (this is Collection) { @@ -93,6 +114,7 @@ public fun Iterable.mean(): Double = if (count > 0) sum / count else Double.NaN } +@Deprecated(MEAN, level = DeprecationLevel.HIDDEN) @JvmName("shortMean") public fun Iterable.mean(): Double = if (this is Collection) { @@ -106,6 +128,7 @@ public fun Iterable.mean(): Double = if (count > 0) sum / count else Double.NaN } +@Deprecated(MEAN, level = DeprecationLevel.HIDDEN) @JvmName("byteMean") public fun Iterable.mean(): Double = if (this is Collection) { @@ -119,6 +142,7 @@ public fun Iterable.mean(): Double = if (count > 0) sum / count else Double.NaN } +@Deprecated(MEAN, level = DeprecationLevel.HIDDEN) @JvmName("longMean") public fun Iterable.mean(): Double = if (this is Collection) { @@ -132,6 +156,7 @@ public fun Iterable.mean(): Double = if (count > 0) sum / count else Double.NaN } +@Deprecated(MEAN, level = DeprecationLevel.HIDDEN) @JvmName("bigDecimalMean") public fun Iterable.mean(): Double = if (this is Collection) { diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/util/deprecationMessages.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/util/deprecationMessages.kt index b25e7fb60..f819b0f9c 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/util/deprecationMessages.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/util/deprecationMessages.kt @@ -5,11 +5,14 @@ package org.jetbrains.kotlinx.dataframe.util * After each release, all messages should be reviewed and updated. * Level.WARNING -> Level.ERROR * Level.ERROR -> Remove + * + * Level.HIDDEN can remain as is but needs to be removed together with other deprecations in + * the same cycle. */ // region WARNING in 0.15, ERROR in 0.16 -private const val MESSAGE_0_16 = "Will be removed in 0.16." +private const val MESSAGE_0_16 = "Will be ERROR in 0.16." internal const val DF_READ_NO_CSV = "This function is deprecated and should be replaced with `readCSV`. $MESSAGE_0_16" internal const val DF_READ_NO_CSV_REPLACE = @@ -44,11 +47,20 @@ internal const val PARSER_OPTIONS = "This constructor is only here for binary co internal const val PARSER_OPTIONS_COPY = "This function is only here for binary compatibility. $MESSAGE_0_16" +internal const val SEQUENCE_FLOAT_MEAN = + "`Sequence.mean()` is removed since it's already covered by other overloads. $MESSAGE_0_16" + +internal const val INTERNAL_MEAN = + "`Iterable.mean(skipNA)` is removed since it's already covered by other overloads. $MESSAGE_0_16" + +internal const val MEAN = + "`Iterable.mean()` is removed from the public API because it's outside the scope of DataFrame. You can still call `.mean()` on a column. For most types there's already `.average()` in stdlib. $MESSAGE_0_16" + // endregion // region WARNING in 0.16, ERROR in 0.17 -private const val MESSAGE_0_17 = "Will be removed in 0.17." +private const val MESSAGE_0_17 = "Will be ERROR in 0.17." // endregion diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt index 385023eda..cf42af3fe 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/cumsum.kt @@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.api.concat import org.jetbrains.kotlinx.dataframe.api.cumSum import org.jetbrains.kotlinx.dataframe.api.dataFrameOf import org.jetbrains.kotlinx.dataframe.api.groupBy +import org.jetbrains.kotlinx.dataframe.api.map import org.junit.Test @Suppress("ktlint:standard:argument-list-wrapping") @@ -22,6 +23,30 @@ class CumsumTests { col.cumSum(skipNA = false).toList() shouldBe expectedNoSkip } + @Test + fun `short column`() { + col.map { it?.toShort() }.cumSum().toList() shouldBe expected.map { it?.toShort() } + col.map { it?.toShort() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toShort() } + } + + @Test + fun `byte column`() { + col.map { it?.toByte() }.cumSum().toList() shouldBe expected.map { it?.toByte() } + col.map { it?.toByte() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toByte() } + } + + @Test + fun `big int column`() { + col.map { it?.toBigInteger() }.cumSum().toList() shouldBe expected.map { it?.toBigInteger() } + col.map { it?.toBigInteger() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toBigInteger() } + } + + @Test + fun `big decimal column`() { + col.map { it?.toBigDecimal() }.cumSum().toList() shouldBe expected.map { it?.toBigDecimal() } + col.map { it?.toBigDecimal() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toBigDecimal() } + } + @Test fun frame() { val str by columnOf("a", "b", "c", "d", "e")