Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perform conversion for the columns output from Table.readJSON to other data types using JSONUtils.convertDataTypes() #11618

Draft
wants to merge 11 commits into
base: branch-24.12
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
package org.apache.spark.sql.rapids

import java.util.Locale

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, NvtxColor, NvtxRange, Scalar, Schema, Table}
//import ai.rapids.cudf.{ColumnVector, ColumnView, DType, NvtxColor, NvtxRange, Scalar, Schema, Table, TableDebug}
import ai.rapids.cudf.{ColumnVector, ColumnView, DType, NvtxColor, NvtxRange, Schema, Table}
import com.fasterxml.jackson.core.JsonParser
import com.nvidia.spark.rapids.{ColumnCastUtil, GpuCast, GpuColumnVector, GpuScalar, GpuTextBasedPartitionReader}
import com.nvidia.spark.rapids.{ColumnCastUtil, GpuColumnVector, GpuScalar, GpuTextBasedPartitionReader}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray
import com.nvidia.spark.rapids.jni.CastStrings
import com.nvidia.spark.rapids.jni.{JSONUtils}

import org.apache.spark.sql.catalyst.json.{GpuJsonUtils, JSONOptions}
import org.apache.spark.sql.rapids.shims.GpuJsonToStructsShim
Expand Down Expand Up @@ -62,159 +62,8 @@ object GpuJsonReadCommon {
builder.build
}

/**
 * Produces a BOOL8 column that is true for every row whose string value both
 * begins and ends with a double-quote character.
 */
private def isQuotedString(input: ColumnView): ColumnVector = {
  withResource(Scalar.fromString("\"")) { quoteChar =>
    withResource(input.startsWith(quoteChar)) { startsWithQuote =>
      withResource(input.endsWith(quoteChar)) { endsWithQuote =>
        startsWithQuote.binaryOp(BinaryOp.LOGICAL_AND, endsWithQuote, DType.BOOL8)
      }
    }
  }
}

/**
 * Removes the first and last character of every string by taking the
 * substring from index 1 up to (length - 1).
 *
 * For rows where the length is null, the end index is substituted with 1 so
 * the substring call still receives a valid index column.
 */
private def stripFirstAndLastChar(input: ColumnView): ColumnVector = {
  withResource(Scalar.fromInt(1)) { one =>
    // end index per row: charLength - 1, with nulls patched to 1
    val endIndices = withResource(input.getCharLengths) { lengths =>
      withResource(lengths.sub(one)) { lengthsMinusOne =>
        withResource(lengthsMinusOne.isNull) { isNullRow =>
          isNullRow.ifElse(one, lengthsMinusOne)
        }
      }
    }
    withResource(endIndices) { _ =>
      // start index is a constant 1 for every row
      withResource(ColumnVector.fromScalar(one, endIndices.getRowCount.toInt)) { startIndices =>
        input.substring(startIndices, endIndices)
      }
    }
  }
}

/**
 * For rows that are wrapped in double quotes, strips the surrounding quotes;
 * all other rows are passed through unchanged.
 */
private def undoKeepQuotes(input: ColumnView): ColumnVector = {
  withResource(isQuotedString(input)) { quotedMask =>
    withResource(stripFirstAndLastChar(input)) { unquoted =>
      quotedMask.ifElse(unquoted, input)
    }
  }
}

/**
 * For rows that are wrapped in double quotes, strips the surrounding quotes;
 * rows that are NOT quoted are replaced with null.
 */
private def fixupQuotedStrings(input: ColumnView): ColumnVector = {
  withResource(isQuotedString(input)) { quotedMask =>
    withResource(stripFirstAndLastChar(input)) { unquoted =>
      withResource(Scalar.fromString(null)) { nullScalar =>
        quotedMask.ifElse(unquoted, nullScalar)
      }
    }
  }
}

// Non-standard floating point literals (NaN and the Infinity spellings) that
// Spark accepts even though they are not valid JSON numbers.
private lazy val specialUnquotedFloats =
Seq("NaN", "+INF", "-INF", "+Infinity", "Infinity", "-Infinity")
// The same literals wrapped in double quotes, as they may appear quoted in the input.
private lazy val specialQuotedFloats = specialUnquotedFloats.map(s => '"'+s+'"')

/**
 * Normalizes special floating point literals ahead of parsing.
 *
 * JSON has strict rules about valid numeric formats (see https://www.json.org/
 * for the specification), but Spark additionally supports NaN and Infinity,
 * which are not valid numbers in JSON.
 */
private def sanitizeFloats(input: ColumnView, options: JSONOptions): ColumnVector = {
  // Note that this is not 100% consistent with Spark versions prior to Spark 3.3.0
  // due to https://issues.apache.org/jira/browse/SPARK-38060
  if (!options.allowNonNumericNumbers) {
    input.copyToColumnVector()
  } else {
    // Replace the quoted spellings with their unquoted forms so they parse properly
    withResource(ColumnVector.fromStrings(specialQuotedFloats: _*)) { quotedForms =>
      withResource(ColumnVector.fromStrings(specialUnquotedFloats: _*)) { unquotedForms =>
        input.findAndReplaceAll(quotedForms, unquotedForms)
      }
    }
  }
}

/**
 * Nulls out rows that cannot be integers because they look like floats.
 *
 * Integer numbers must not contain '.', 'e', or 'E'. Rows containing any of
 * those characters are replaced with null; the remaining validation is
 * performed inside CUDF itself.
 */
private def sanitizeInts(input: ColumnView): ColumnVector = {
  // rows containing '.' or lowercase 'e'
  val hasDotOrLowerE = withResource(Scalar.fromString(".")) { dot =>
    withResource(input.stringContains(dot)) { containsDot =>
      withResource(Scalar.fromString("e")) { lowerE =>
        withResource(input.stringContains(lowerE)) { containsLowerE =>
          containsDot.or(containsLowerE)
        }
      }
    }
  }
  // fold in uppercase 'E' as well
  val looksLikeFloat = withResource(hasDotOrLowerE) { _ =>
    withResource(Scalar.fromString("E")) { upperE =>
      withResource(input.stringContains(upperE)) { containsUpperE =>
        hasDotOrLowerE.or(containsUpperE)
      }
    }
  }
  withResource(looksLikeFloat) { _ =>
    withResource(Scalar.fromNull(DType.STRING)) { nullString =>
      looksLikeFloat.ifElse(nullString, input)
    }
  }
}

/**
 * Prepares quoted decimal strings for parsing under the US locale.
 *
 * The US locale is special in that the ',' separators are simply removed and
 * the rest of the input is parsed normally, so this strips the surrounding
 * quotes and deletes every comma.
 */
private def sanitizeQuotedDecimalInUSLocale(input: ColumnView): ColumnVector = {
  withResource(stripFirstAndLastChar(input)) { unquoted =>
    withResource(Scalar.fromString(",")) { commaChar =>
      withResource(Scalar.fromString("")) { emptyString =>
        unquoted.stringReplace(commaChar, emptyString)
      }
    }
  }
}

/**
 * Normalizes decimal strings before parsing: quoted rows are cleaned up for
 * US-locale parsing, unquoted rows pass through unchanged.
 *
 * Only the US locale is supported here, which is enforced by the assertion.
 */
private def sanitizeDecimal(input: ColumnView, options: JSONOptions): ColumnVector = {
  assert(options.locale == Locale.US)
  withResource(isQuotedString(input)) { quotedMask =>
    withResource(sanitizeQuotedDecimalInUSLocale(input)) { cleanedQuoted =>
      quotedMask.ifElse(cleanedQuoted, input)
    }
  }
}

/**
 * Casts a string column to the requested floating point type, after first
 * normalizing Spark's special NaN/Infinity spellings via [[sanitizeFloats]].
 */
private def castStringToFloat(input: ColumnView, dt: DType,
    options: JSONOptions): ColumnVector = {
  withResource(sanitizeFloats(input, options)) { normalized =>
    CastStrings.toFloat(normalized, false, dt)
  }
}

// Casts a string column to a decimal column with the precision/scale of `dt`.
// Note the scale is negated to match the CastStrings convention.
private def castStringToDecimal(input: ColumnVector, dt: DecimalType): ColumnVector = {
// TODO there is a bug here around 0 https://github.com/NVIDIA/spark-rapids/issues/10898
CastStrings.toDecimal(input, false, false, dt.precision, -dt.scale)
}

/**
 * Converts JSON string values to booleans: exactly "true" becomes true,
 * exactly "false" becomes false, and anything else becomes null.
 *
 * Sadly there is no good kernel right now to do just this check/conversion,
 * so it is composed from equality checks and ifElse selections.
 */
private def castJsonStringToBool(input: ColumnView): ColumnVector = {
  val matchesTrue = withResource(Scalar.fromString("true")) { trueText =>
    input.equalTo(trueText)
  }
  withResource(matchesTrue) { _ =>
    val matchesFalse = withResource(Scalar.fromString("false")) { falseText =>
      input.equalTo(falseText)
    }
    // first resolve everything that is not "true": false where it matched
    // "false", null otherwise
    val falseOrNull = withResource(matchesFalse) { _ =>
      withResource(Scalar.fromBool(false)) { falseValue =>
        withResource(Scalar.fromNull(DType.BOOL8)) { nullValue =>
          matchesFalse.ifElse(falseValue, nullValue)
        }
      }
    }
    // then overlay true wherever the input matched "true"
    withResource(falseOrNull) { _ =>
      withResource(Scalar.fromBool(true)) { trueValue =>
        matchesTrue.ifElse(trueValue, falseOrNull)
      }
    }
  }
}

// The date format to use when reading, if one was configured in the JSON options.
private def dateFormat(options: JSONOptions): Option[String] =
GpuJsonUtils.optionalDateFormatInRead(options)
Expand Down Expand Up @@ -269,34 +118,58 @@ object GpuJsonReadCommon {
options: JSONOptions): ColumnVector = {
ColumnCastUtil.deepTransform(inputCv, Some(topLevelType),
Some(nestedColumnViewMismatchTransform)) {

//
// DONE
case (cv, Some(BooleanType)) if cv.getType == DType.STRING =>
castJsonStringToBool(cv)
JSONUtils.castStringsToBooleans(cv)
//
//

case (cv, Some(DateType)) if cv.getType == DType.STRING =>
withResource(fixupQuotedStrings(cv)) { fixed =>
withResource(JSONUtils.removeQuotes(cv, true)) { fixed =>
GpuJsonToStructsShim.castJsonStringToDateFromScan(fixed, DType.TIMESTAMP_DAYS,
dateFormat(options))
}
case (cv, Some(TimestampType)) if cv.getType == DType.STRING =>
withResource(fixupQuotedStrings(cv)) { fixed =>
withResource(JSONUtils.removeQuotes(cv, true)) { fixed =>
GpuTextBasedPartitionReader.castStringToTimestamp(fixed, timestampFormat(options),
DType.TIMESTAMP_MICROSECONDS)
}

//
// Done
case (cv, Some(StringType)) if cv.getType == DType.STRING =>
undoKeepQuotes(cv)
JSONUtils.removeQuotes(cv, false)
//
//

//
// Done
case (cv, Some(dt: DecimalType)) if cv.getType == DType.STRING =>
withResource(sanitizeDecimal(cv, options)) { tmp =>
castStringToDecimal(tmp, dt)
}
JSONUtils.castStringsToDecimals(cv, dt.precision, -dt.scale, options.locale == Locale.US)
//
//

//
// DONE
case (cv, Some(dt)) if (dt == DoubleType || dt == FloatType) && cv.getType == DType.STRING =>
castStringToFloat(cv, GpuColumnVector.getNonNestedRapidsType(dt), options)
JSONUtils.castStringsToFloats(cv, GpuColumnVector.getNonNestedRapidsType(dt),
options.allowNonNumericNumbers)
//
//

//
// DONE
case (cv, Some(dt))
if (dt == ByteType || dt == ShortType || dt == IntegerType || dt == LongType ) &&
cv.getType == DType.STRING =>
withResource(sanitizeInts(cv)) { tmp =>
CastStrings.toInteger(tmp, false, GpuColumnVector.getNonNestedRapidsType(dt))
}
JSONUtils.castStringsToIntegers(cv, GpuColumnVector.getNonNestedRapidsType(dt))
//
//

case (cv, Some(dt)) if cv.getType == DType.STRING =>
GpuCast.doCast(cv, StringType, dt)
throw new JsonParsingException(s"Cannot convert string to $dt", null)
}
}

Expand Down
Loading