Skip to content

Commit

Permalink
Add precision only in Schema
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed Oct 25, 2024
1 parent 1c379f5 commit 7fcc48e
Showing 1 changed file with 44 additions and 15 deletions.
59 changes: 44 additions & 15 deletions java/src/main/java/ai/rapids/cudf/Schema.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,42 @@
public class Schema {
public static final Schema INFERRED = new Schema();

private static final int UNKNOWN_PRECISION = -1;

private final DType topLevelType;
private final int precision; // storing precision for decimal types
private final List<String> childNames;
private final List<Schema> childSchemas;
private boolean flattened = false;
private String[] flattenedNames;
private DType[] flattenedTypes;
private int[] flattenedCounts;
private int[] flattenedPrecisions;



/**
* Creates a schema node that records an explicit precision alongside its type.
* All arguments are stored as given; no copies are made and nothing is validated here.
* @param topLevelType the type of this schema node
* @param precision the precision to record for this node (meaningful for decimal types)
* @param childNames the names of the children of this node
* @param childSchemas the schemas of the children of this node
*/
private Schema(DType topLevelType,
int precision,
List<String> childNames,
List<Schema> childSchemas) {
// The four fields are independent of each other, so assignment order is irrelevant.
this.childSchemas = childSchemas;
this.childNames = childNames;
this.precision = precision;
this.topLevelType = topLevelType;
}

/**
* Creates a schema node without an explicit precision. Delegates to the four-argument
* constructor, passing UNKNOWN_PRECISION as the precision.
* @param topLevelType the type of this schema node
* @param childNames the names of the children of this node
* @param childSchemas the schemas of the children of this node
*/
private Schema(DType topLevelType,
List<String> childNames,
List<Schema> childSchemas) {
this(topLevelType, UNKNOWN_PRECISION, childNames, childSchemas);
}

/**
* Inferred schema: there is no top level type and there are no children, and the
* precision is left as UNKNOWN_PRECISION.
*/
private Schema() {
this.topLevelType = null;
this.childNames = null;
this.childSchemas = null;
this.precision = UNKNOWN_PRECISION;
}
Expand Down Expand Up @@ -105,14 +120,17 @@ private void flattenIfNeeded() {
flattenedNames = null;
flattenedTypes = null;
flattenedCounts = null;
flattenedPrecisions = null;
} else {
String[] names = new String[flatLen];
DType[] types = new DType[flatLen];
int[] counts = new int[flatLen];
collectFlattened(names, types, counts, 0);
int[] precisions = new int[flatLen];
collectFlattened(names, types, counts, precisions, 0);
flattenedNames = names;
flattenedTypes = types;
flattenedCounts = counts;
flattenedPrecisions = precisions;
}
flattened = true;
}
Expand All @@ -128,19 +146,20 @@ private int flattenedLength(int startingLength) {
return startingLength;
}

private int collectFlattened(String[] names, DType[] types, int[] counts, int offset) {
private int collectFlattened(String[] names, DType[] types, int[] counts, int[] precisions, int offset) {
if (childSchemas != null) {
for (int i = 0; i < childSchemas.size(); i++) {
Schema child = childSchemas.get(i);
names[offset] = childNames.get(i);
types[offset] = child.topLevelType;
precisions[offset] = child.precision;
if (child.childNames != null) {
counts[offset] = child.childNames.size();
} else {
counts[offset] = 0;
}
offset++;
offset = this.childSchemas.get(i).collectFlattened(names, types, counts, offset);
offset = this.childSchemas.get(i).collectFlattened(names, types, counts, precisions, offset);
}
}
return offset;
Expand Down Expand Up @@ -233,14 +252,7 @@ public int[] getFlattenedTypeScales() {
*/
public int[] getFlattenedDecimalPrecisions() {
flattenIfNeeded();
if (flattenedTypes == null) {
return null;
}
int[] ret = new int[flattenedTypes.length];
for (int i = 0; i < flattenedTypes.length; i++) {
ret[i] = flattenedTypes[i].getDecimalPrecision();
}
return ret;
return flattenedPrecisions;
}

/**
Expand Down Expand Up @@ -324,11 +336,13 @@ public HostColumnVector.DataType asHostDataType() {

public static class Builder {
private final DType topLevelType;
private final int topLevelPrecision;
private final List<String> names;
private final List<Builder> types;

private Builder(DType topLevelType) {
private Builder(DType topLevelType, int precision) {
this.topLevelType = topLevelType;
this.topLevelPrecision = precision;
if (topLevelType == DType.STRUCT || topLevelType == DType.LIST) {
// There can be children
names = new ArrayList<>();
Expand All @@ -339,14 +353,19 @@ private Builder(DType topLevelType) {
}
}

/**
* Creates a builder without an explicit precision. Delegates to the two-argument
* constructor, passing UNKNOWN_PRECISION as the precision.
* @param topLevelType the type this builder describes
*/
private Builder(DType topLevelType) {
this(topLevelType, UNKNOWN_PRECISION);
}

/**
* Add a new column
* @param type the type of column to add
* @param name the name of the column to add (Ignored for list types)
* @param precision the decimal precision, only applicable for decimal types
* @return the builder for the new column. This should really only be used when the type
* passed in is a LIST or a STRUCT.
*/
public Builder addColumn(DType type, String name) {
public Builder addColumn(DType type, String name, int precision) {
if (names == null) {
throw new IllegalStateException("A column of type " + topLevelType +
" cannot have children");
Expand All @@ -357,21 +376,31 @@ public Builder addColumn(DType type, String name) {
if (names.contains(name)) {
throw new IllegalStateException("Cannot add duplicate names to a schema");
}
Builder ret = new Builder(type);
Builder ret = new Builder(type, precision);
types.add(ret);
names.add(name);
return ret;
}

/**
* Add a new column without specifying a decimal precision. This is equivalent to calling
* the three-argument addColumn with UNKNOWN_PRECISION as the precision.
* @param type the type of column to add
* @param name the name of the column to add
* @return the builder for the new column, as returned by the three-argument addColumn.
*/
public Builder addColumn(DType type, String name) {
return addColumn(type, name, UNKNOWN_PRECISION);
}

/**
* Adds a single column with an explicit precision to the current schema. addColumn is
* preferred as it can be used to support nested types: this method discards the builder
* that addColumn returns for the new column and returns the current builder instead.
* @param type the type of the column.
* @param name the name of the column.
* @param precision the decimal precision, only applicable for decimal types.
* @return this for chaining.
*/
public Builder column(DType type, String name, int precision) {
addColumn(type, name, precision);
return this;
}

public Builder column(DType type, String name) {
addColumn(type, name);
addColumn(type, name, UNKNOWN_PRECISION);
return this;
}

Expand Down

0 comments on commit 7fcc48e

Please sign in to comment.