Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Statistics fixes #937

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.math.cumSum
import org.jetbrains.kotlinx.dataframe.math.defaultCumSumSkipNA
import org.jetbrains.kotlinx.dataframe.typeClass
import java.math.BigDecimal
import java.math.BigInteger
import kotlin.reflect.KProperty
import kotlin.reflect.typeOf

Expand All @@ -22,15 +23,30 @@ public fun <T : Number?> DataColumn<T>.cumSum(skipNA: Boolean = defaultCumSumSki
typeOf<Float?>() -> cast<Float?>().cumSum(skipNA).cast()
typeOf<Int>() -> cast<Int>().cumSum().cast()
typeOf<Int?>() -> cast<Int?>().cumSum(skipNA).cast()
typeOf<Byte>() -> cast<Byte>().cumSum().cast()
typeOf<Byte?>() -> cast<Byte?>().cumSum(skipNA).cast()
typeOf<Short>() -> cast<Short>().cumSum().cast()
typeOf<Short?>() -> cast<Short?>().cumSum(skipNA).cast()
typeOf<Long>() -> cast<Long>().cumSum().cast()
typeOf<Long?>() -> cast<Long?>().cumSum(skipNA).cast()
typeOf<BigInteger>() -> cast<BigInteger>().cumSum().cast()
typeOf<BigInteger?>() -> cast<BigInteger?>().cumSum(skipNA).cast()
typeOf<BigDecimal>() -> cast<BigDecimal>().cumSum().cast()
typeOf<BigDecimal?>() -> cast<BigDecimal?>().cumSum(skipNA).cast()
typeOf<Number?>(), typeOf<Number>() -> convertToDouble().cumSum(skipNA).cast()
else -> error("Cumsum for type ${type()} is not supported")
}

private val supportedClasses = setOf(Double::class, Float::class, Int::class, Long::class, BigDecimal::class)
private val supportedClasses = setOf(
Double::class,
Float::class,
Int::class,
Byte::class,
Short::class,
Long::class,
BigInteger::class,
BigDecimal::class,
)

// endregion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.isNA
import org.jetbrains.kotlinx.dataframe.api.map
import java.math.BigDecimal
import java.math.BigInteger

internal val defaultCumSumSkipNA: Boolean = true

Expand Down Expand Up @@ -88,6 +89,66 @@ internal fun DataColumn<Int?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): Dat
}
}

@JvmName("byteCumsum")
internal fun DataColumn<Byte>.cumSum(): DataColumn<Byte> {
var sum = 0.toByte()
return map {
sum = (sum + it).toByte()
sum
}
}

@JvmName("cumsumByteNullable")
internal fun DataColumn<Byte?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Byte?> {
var sum = 0.toByte()
var fillNull = false
return map {
when {
it == null -> {
if (!skipNA) fillNull = true
null
}

fillNull -> null

else -> {
sum = (sum + it).toByte()
sum
}
}
}
}

@JvmName("shortCumsum")
internal fun DataColumn<Short>.cumSum(): DataColumn<Short> {
var sum = 0.toShort()
return map {
sum = (sum + it).toShort()
sum
}
}

@JvmName("cumsumShortNullable")
internal fun DataColumn<Short?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<Short?> {
var sum = 0.toShort()
var fillNull = false
return map {
when {
it == null -> {
if (!skipNA) fillNull = true
null
}

fillNull -> null

else -> {
sum = (sum + it).toShort()
sum
}
}
}
}

@JvmName("longCumsum")
internal fun DataColumn<Long>.cumSum(): DataColumn<Long> {
var sum = 0L
Expand Down Expand Up @@ -118,6 +179,36 @@ internal fun DataColumn<Long?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): Da
}
}

@JvmName("bigIntegerCumsum")
internal fun DataColumn<BigInteger>.cumSum(): DataColumn<BigInteger> {
var sum = BigInteger.ZERO
return map {
sum += it
sum
}
}

@JvmName("cumsumBigIntegerNullable")
internal fun DataColumn<BigInteger?>.cumSum(skipNA: Boolean = defaultCumSumSkipNA): DataColumn<BigInteger?> {
var sum = BigInteger.ZERO
var fillNull = false
return map {
when {
it == null -> {
if (!skipNA) fillNull = true
null
}

fillNull -> null

else -> {
sum += it
sum
}
}
}
}

@JvmName("bigDecimalCumsum")
internal fun DataColumn<BigDecimal>.cumSum(): DataColumn<BigDecimal> {
var sum = BigDecimal.ZERO
Expand Down
35 changes: 30 additions & 5 deletions core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/mean.kt
Original file line number Diff line number Diff line change
@@ -1,23 +1,38 @@
package org.jetbrains.kotlinx.dataframe.math

import org.jetbrains.kotlinx.dataframe.api.mean
import org.jetbrains.kotlinx.dataframe.api.skipNA_default
import org.jetbrains.kotlinx.dataframe.impl.renderType
import org.jetbrains.kotlinx.dataframe.util.INTERNAL_MEAN
import org.jetbrains.kotlinx.dataframe.util.MEAN
import org.jetbrains.kotlinx.dataframe.util.SEQUENCE_FLOAT_MEAN
import java.math.BigDecimal
import java.math.BigInteger
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability
import kotlin.reflect.typeOf

@JvmName("meanIterableReified")
@PublishedApi
internal inline fun <reified T : Number> Iterable<T>.mean(skipNA: Boolean = skipNA_default): Double =
mean(typeOf<T>(), skipNA)

@PublishedApi
internal fun <T : Number> Iterable<T>.mean(type: KType, skipNA: Boolean = skipNA_default): Double =
asSequence().mean(type, skipNA)

@JvmName("meanSequenceReified")
internal inline fun <reified T : Number> Sequence<T>.mean(skipNA: Boolean = skipNA_default): Double =
mean(typeOf<T>(), skipNA)

internal fun <T : Number> Sequence<T>.mean(type: KType, skipNA: Boolean = skipNA_default): Double {
if (type.isMarkedNullable) {
return filterNotNull().mean(type.withNullability(false), skipNA)
}
return when (type.classifier) {
Double::class -> (this as Sequence<Double>).mean(skipNA)

Float::class -> (this as Sequence<Float>).mean(skipNA)
Float::class -> (this as Sequence<Float>).map { it.toDouble() }.mean(skipNA)

Int::class -> (this as Sequence<Int>).map { it.toDouble() }.mean(false)

Expand All @@ -28,6 +43,8 @@ internal fun <T : Number> Sequence<T>.mean(type: KType, skipNA: Boolean = skipNA

Long::class -> (this as Sequence<Long>).map { it.toDouble() }.mean(false)

BigInteger::class -> (this as Sequence<BigInteger>).map { it.toDouble() }.mean(false)

BigDecimal::class -> (this as Sequence<BigDecimal>).map { it.toDouble() }.mean(skipNA)

Number::class -> (this as Sequence<Number>).map { it.toDouble() }.mean(skipNA)
Expand All @@ -39,7 +56,7 @@ internal fun <T : Number> Sequence<T>.mean(type: KType, skipNA: Boolean = skipNA
}
}

public fun Sequence<Double>.mean(skipNA: Boolean = skipNA_default): Double {
private fun Sequence<Double>.mean(skipNA: Boolean = skipNA_default): Double {
var count = 0
var sum: Double = 0.toDouble()
for (element in this) {
Expand All @@ -56,8 +73,9 @@ public fun Sequence<Double>.mean(skipNA: Boolean = skipNA_default): Double {
return if (count > 0) sum / count else Double.NaN
}

@Deprecated(SEQUENCE_FLOAT_MEAN, level = DeprecationLevel.ERROR)
@JvmName("meanFloat")
public fun Sequence<Float>.mean(skipNA: Boolean = skipNA_default): Double {
internal fun Sequence<Float>.mean(skipNA: Boolean = skipNA_default): Double {
var count = 0
var sum: Double = 0.toDouble()
for (element in this) {
Expand All @@ -74,12 +92,15 @@ public fun Sequence<Float>.mean(skipNA: Boolean = skipNA_default): Double {
return if (count > 0) sum / count else Double.NaN
}

@Deprecated(INTERNAL_MEAN, level = DeprecationLevel.HIDDEN)
@JvmName("doubleMean")
public fun Iterable<Double>.mean(skipNA: Boolean = skipNA_default): Double = asSequence().mean(skipNA)
internal fun Iterable<Double>.mean(skipNA: Boolean = skipNA_default): Double = mean(typeOf<Double>(), skipNA)

@Deprecated(INTERNAL_MEAN, level = DeprecationLevel.HIDDEN)
@JvmName("floatMean")
public fun Iterable<Float>.mean(skipNA: Boolean = skipNA_default): Double = asSequence().mean(skipNA)
internal fun Iterable<Float>.mean(skipNA: Boolean = skipNA_default): Double = mean(typeOf<Float>(), skipNA)

@Deprecated(MEAN, level = DeprecationLevel.HIDDEN)
@JvmName("intMean")
public fun Iterable<Int>.mean(): Double =
if (this is Collection) {
Expand All @@ -93,6 +114,7 @@ public fun Iterable<Int>.mean(): Double =
if (count > 0) sum / count else Double.NaN
}

@Deprecated(MEAN, level = DeprecationLevel.HIDDEN)
@JvmName("shortMean")
public fun Iterable<Short>.mean(): Double =
if (this is Collection) {
Expand All @@ -106,6 +128,7 @@ public fun Iterable<Short>.mean(): Double =
if (count > 0) sum / count else Double.NaN
}

@Deprecated(MEAN, level = DeprecationLevel.HIDDEN)
@JvmName("byteMean")
public fun Iterable<Byte>.mean(): Double =
if (this is Collection) {
Expand All @@ -119,6 +142,7 @@ public fun Iterable<Byte>.mean(): Double =
if (count > 0) sum / count else Double.NaN
}

@Deprecated(MEAN, level = DeprecationLevel.HIDDEN)
@JvmName("longMean")
public fun Iterable<Long>.mean(): Double =
if (this is Collection) {
Expand All @@ -132,6 +156,7 @@ public fun Iterable<Long>.mean(): Double =
if (count > 0) sum / count else Double.NaN
}

@Deprecated(MEAN, level = DeprecationLevel.HIDDEN)
@JvmName("bigDecimalMean")
public fun Iterable<BigDecimal>.mean(): Double =
if (this is Collection) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ package org.jetbrains.kotlinx.dataframe.util
* After each release, all messages should be reviewed and updated.
* Level.WARNING -> Level.ERROR
* Level.ERROR -> Remove
*
* Level.HIDDEN can remain as is but needs to be removed together with other deprecations in
* the same cycle.
*/

// region WARNING in 0.15, ERROR in 0.16

private const val MESSAGE_0_16 = "Will be removed in 0.16."
private const val MESSAGE_0_16 = "Will be ERROR in 0.16."

internal const val DF_READ_NO_CSV = "This function is deprecated and should be replaced with `readCSV`. $MESSAGE_0_16"
internal const val DF_READ_NO_CSV_REPLACE =
Expand Down Expand Up @@ -44,11 +47,20 @@ internal const val PARSER_OPTIONS = "This constructor is only here for binary co

internal const val PARSER_OPTIONS_COPY = "This function is only here for binary compatibility. $MESSAGE_0_16"

internal const val SEQUENCE_FLOAT_MEAN =
"`Sequence<Float>.mean()` is removed since it's already covered by other overloads. $MESSAGE_0_16"

internal const val INTERNAL_MEAN =
"`Iterable.mean(skipNA)` is removed since it's already covered by other overloads. $MESSAGE_0_16"

internal const val MEAN =
"`Iterable.mean()` is removed from the public API because it's outside the scope of DataFrame. You can still call `.mean()` on a column. For most types there's already `.average()` in stdlib. $MESSAGE_0_16"

// endregion

// region WARNING in 0.16, ERROR in 0.17

private const val MESSAGE_0_17 = "Will be removed in 0.17."
private const val MESSAGE_0_17 = "Will be ERROR in 0.17."

// endregion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.api.concat
import org.jetbrains.kotlinx.dataframe.api.cumSum
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.api.groupBy
import org.jetbrains.kotlinx.dataframe.api.map
import org.junit.Test

@Suppress("ktlint:standard:argument-list-wrapping")
Expand All @@ -22,6 +23,30 @@ class CumsumTests {
col.cumSum(skipNA = false).toList() shouldBe expectedNoSkip
}

@Test
fun `short column`() {
col.map { it?.toShort() }.cumSum().toList() shouldBe expected.map { it?.toShort() }
col.map { it?.toShort() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toShort() }
}

@Test
fun `byte column`() {
col.map { it?.toByte() }.cumSum().toList() shouldBe expected.map { it?.toByte() }
col.map { it?.toByte() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toByte() }
}

@Test
fun `big int column`() {
col.map { it?.toBigInteger() }.cumSum().toList() shouldBe expected.map { it?.toBigInteger() }
col.map { it?.toBigInteger() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toBigInteger() }
}

@Test
fun `big decimal column`() {
col.map { it?.toBigDecimal() }.cumSum().toList() shouldBe expected.map { it?.toBigDecimal() }
col.map { it?.toBigDecimal() }.cumSum(skipNA = false).toList() shouldBe expectedNoSkip.map { it?.toBigDecimal() }
}

@Test
fun frame() {
val str by columnOf("a", "b", "c", "d", "e")
Expand Down