diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index 734b4dfb708..2645d1d5f98 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -1118,7 +1118,8 @@ def test_hash_groupby_typed_imperative_agg_without_gpu_implementation_fallback() @disable_ansi_mode # https://github.com/NVIDIA/spark-rapids/issues/5114 @pytest.mark.parametrize('data_gen', _init_list, ids=idfn) @pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn) -def test_hash_multiple_mode_query(data_gen, conf): +@pytest.mark.parametrize('shuffle_split', [True, False], ids=idfn) +def test_hash_multiple_mode_query(data_gen, conf, shuffle_split): print_params(data_gen) assert_gpu_and_cpu_are_equal_collect( lambda spark: gen_df(spark, data_gen, length=100) @@ -1132,7 +1133,10 @@ def test_hash_multiple_mode_query(data_gen, conf): f.max('a'), f.sumDistinct('b'), f.countDistinct('c') - ), conf=conf) + ), + conf=copy_and_update( + conf, + {'spark.rapids.shuffle.splitRetryRead.enabled': shuffle_split})) @approximate_float diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py index 703fbe80230..d23d74da524 100644 --- a/integration_tests/src/main/python/join_test.py +++ b/integration_tests/src/main/python/join_test.py @@ -242,10 +242,12 @@ def test_hash_join_ridealong_non_sized(data_gen, join_type, sub_part_enabled): @ignore_order(local=True) @pytest.mark.parametrize('data_gen', basic_nested_gens + [decimal_gen_128bit], ids=idfn) @pytest.mark.parametrize('join_type', all_symmetric_sized_join_types, ids=idfn) +@pytest.mark.parametrize('shuffle_split', [True, False], ids=idfn) @allow_non_gpu(*non_utc_allow) -def test_hash_join_ridealong_symmetric(data_gen, join_type): +def test_hash_join_ridealong_symmetric(data_gen, join_type, shuffle_split): confs = { "spark.rapids.sql.join.useShuffledSymmetricHashJoin": "true", + "spark.rapids.shuffle.splitRetryRead.enabled": shuffle_split, } hash_join_ridealong(data_gen, join_type, confs) @@ -253,10 +255,12 @@ def test_hash_join_ridealong_symmetric(data_gen, join_type): @ignore_order(local=True) @pytest.mark.parametrize('data_gen', basic_nested_gens + [decimal_gen_128bit], ids=idfn) @pytest.mark.parametrize('join_type', all_asymmetric_sized_join_types, ids=idfn) +@pytest.mark.parametrize('shuffle_split', [True, False], ids=idfn) @allow_non_gpu(*non_utc_allow) -def test_hash_join_ridealong_asymmetric(data_gen, join_type): +def test_hash_join_ridealong_asymmetric(data_gen, join_type, shuffle_split): confs = { "spark.rapids.sql.join.useShuffledAsymmetricHashJoin": "true", + "spark.rapids.shuffle.splitRetryRead.enabled": shuffle_split, } hash_join_ridealong(data_gen, join_type, confs) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala index a88bd9f2cfb..4fc1696113a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,13 +18,19 @@ package com.nvidia.spark.rapids import java.util -import ai.rapids.cudf.{HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange} -import ai.rapids.cudf.JCudfSerialization.{HostConcatResult, SerializedTableHeader} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import scala.collection.mutable +import scala.reflect.ClassTag + +import ai.rapids.cudf.{JCudfSerialization, NvtxColor, NvtxRange} +import ai.rapids.cudf.JCudfSerialization.HostConcatResult +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource, withResourceIfAllowed} +import com.nvidia.spark.rapids.RapidsPluginImplicits.{AutoCloseableProducingSeq, AutoCloseableSeq} +import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitTargetSizeInHalfGpu, withRetry} import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion import com.nvidia.spark.rapids.shims.ShimUnaryExecNode import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute @@ -66,78 +72,339 @@ case class GpuShuffleCoalesceExec(child: SparkPlan, targetBatchByteSize: Long) val metricsMap = allMetrics val targetSize = targetBatchByteSize val dataTypes = GpuColumnVector.extractTypes(schema) + val readOption = CoalesceReadOption(new RapidsConf(conf)) child.executeColumnar().mapPartitions { iter => - new GpuShuffleCoalesceIterator( - new HostShuffleCoalesceIterator(iter, targetSize, metricsMap), - dataTypes, metricsMap) + GpuShuffleCoalesceUtils.getGpuShuffleCoalesceIterator(iter, targetSize, dataTypes, + readOption, metricsMap) } } } +/** A case class to pack some options. Now it has only one, but may have more in the future */ +case class CoalesceReadOption private(useSplitRetryRead: Boolean) + +object CoalesceReadOption { + def apply(conf: RapidsConf): CoalesceReadOption = + CoalesceReadOption(conf.shuffleSplitRetryReadEnabled) +} + +object GpuShuffleCoalesceUtils { + def getGpuShuffleCoalesceIterator( + iter: Iterator[ColumnarBatch], + targetSize: Long, + dataTypes: Array[DataType], + readOption: CoalesceReadOption, + metricsMap: Map[String, GpuMetric], + prefetchFirstBatch: Boolean = false): Iterator[ColumnarBatch] = { + if (readOption.useSplitRetryRead) { + val reader = new GpuShuffleCoalesceReader(iter, targetSize, dataTypes, metricsMap) + if (prefetchFirstBatch) { + withResource(new NvtxRange("fetch first batch", NvtxColor.YELLOW)) { _ => + // Force a coalesce of the first batch before we grab the GPU semaphore + reader.prefetchHeadOnHost() + } + } + reader.asIterator + } else { + val hostIter = new HostShuffleCoalesceIterator(iter, targetSize, metricsMap) + val maybeBufferedIter = if (prefetchFirstBatch) { + val bufferedIter = new CloseableBufferedIterator(hostIter) + withResource(new NvtxRange("fetch first batch", NvtxColor.YELLOW)) { _ => + // Force a coalesce of the first batch before we grab the GPU semaphore + bufferedIter.headOption + } + bufferedIter + } else { + hostIter + } + new GpuShuffleCoalesceIterator(maybeBufferedIter, dataTypes, metricsMap) + } + } + + /** Try to convert a concatenated host buffer to a GPU batch. 
*/ + def toGpuIfAllowed( + table: AnyRef, + dataTypes: Array[DataType]): ColumnarBatch = table match { + case c: HostConcatResult => + cudf_utils.HostConcatResultUtil.getColumnarBatch(c, dataTypes) + case o => + throw new IllegalArgumentException(s"unsupported type: ${o.getClass}") + } + + /** + * Get an iterator that can coalesce the serialized small host batches just + * returned by the Shuffle deserializer. + */ + def getHostShuffleCoalesceIterator( + iter: BufferedIterator[ColumnarBatch], + targetSize: Long, + coalesceMetrics: Map[String, GpuMetric]): Option[Iterator[AutoCloseable]] = { + var retIter: Option[Iterator[AutoCloseable]] = None + if (iter.hasNext && iter.head.numCols() == 1) { + iter.head.column(0) match { + case _: SerializedTableColumn => + retIter = Some(new HostShuffleCoalesceIterator(iter, targetSize, coalesceMetrics)) + case _ => // should be gpu batches + } + } + retIter + } + + /** Get the buffer size of a serialized batch just returned by the Shuffle deserializer */ + def getSerializedBufferSize(cb: ColumnarBatch): Long = { + assert(cb.numCols() == 1) + val hmb = cb.column(0) match { + case serCol: SerializedTableColumn => + serCol.hostBuffer + case o => throw new IllegalStateException(s"unsupported type: ${o.getClass}") + } + if (hmb != null) hmb.getLength else 0L + } + + /** + * Get the buffer size of the coalesced result. It accepts either a concatenated + * host buffer from the Shuffle coalesce exec or a coalesced GPU batch. + */ + def getCoalescedBufferSize(concated: AnyRef): Long = concated match { + case c: HostConcatResult => c.getTableHeader.getDataLen + case g => GpuColumnVector.getTotalDeviceMemoryUsed(g.asInstanceOf[ColumnarBatch]) + } +} + +/** + * A trait defining some operations on the table type T. + * This is used by GpuShuffleCoalesceReaderBase and HostCoalesceIteratorBase to separate the + * table operations from the shuffle read process. + */ +sealed trait TableOperator[T <: AutoCloseable, C] { + def getDataLen(table: T): Long + def getNumRows(table: T): Int + def concatOnHost(tables: Array[T]): C + def toGpu(c: C, dataTypes: Array[DataType]): ColumnarBatch +} + +class CudfTableOperator extends TableOperator[SerializedTableColumn, HostConcatResult] { + override def getDataLen(table: SerializedTableColumn): Long = table.header.getDataLen + override def getNumRows(table: SerializedTableColumn): Int = table.header.getNumRows + + override def concatOnHost(tables: Array[SerializedTableColumn]): HostConcatResult = { + assert(tables.nonEmpty, "no tables to be concatenated") + val numCols = tables.head.header.getNumColumns + if (numCols == 0) { + val totalRowsNum = tables.map(getNumRows).sum + cudf_utils.HostConcatResultUtil.rowsOnlyHostConcatResult(totalRowsNum) + } else { + val (headers, buffers) = tables.map(t => (t.header, t.hostBuffer)).unzip + JCudfSerialization.concatToHostBuffer(headers, buffers) + } + } + + override def toGpu(c: HostConcatResult, dataTypes: Array[DataType]): ColumnarBatch = { + cudf_utils.HostConcatResultUtil.getColumnarBatch(c, dataTypes) + } +} + +/** + * Reader to coalesce columnar batches that are expected to contain only serialized + * tables of type T from shuffle. The serialized tables within are collected up to the + * target batch size, concatenated on the host, and the concatenated result is then + * sent to the GPU. + * If an OOM occurs, the reader halves the target size, concatenates only the cached + * tables that fit within the reduced size, and sends that result to the GPU again.
+ */ +abstract class GpuShuffleCoalesceReaderBase[T <: AutoCloseable: ClassTag, C]( + iter: Iterator[ColumnarBatch], + targetBatchSize: Long, + dataTypes: Array[DataType], + metricsMap: Map[String, GpuMetric]) extends AutoCloseable with Logging { + private[this] val opTimeMetric = metricsMap(GpuMetric.OP_TIME) + private[this] val concatTimeMetric = metricsMap(GpuMetric.CONCAT_TIME) + private[this] val inputBatchesMetric = metricsMap(GpuMetric.NUM_INPUT_BATCHES) + private[this] val inputRowsMetric = metricsMap(GpuMetric.NUM_INPUT_ROWS) + private[this] val outputBatchesMetric = metricsMap(GpuMetric.NUM_OUTPUT_BATCHES) + private[this] val outputRowsMetric = metricsMap(GpuMetric.NUM_OUTPUT_ROWS) + + private[this] val serializedTables = new mutable.Queue[T] + private[this] var realBatchSize = math.max(targetBatchSize, 1) + private[this] var closed = false + + protected def tableOperator: TableOperator[T, C] + + // Don't install the callback if in a unit test + Option(TaskContext.get()).foreach { tc => + onTaskCompletion(tc)(close()) + } + + override def close(): Unit = if (!closed) { + serializedTables.foreach(_.close()) + serializedTables.clear() + closed = true + } + + private def pullNextBatch(): Boolean = { + if (closed) return false + // Always make sure enough data has been cached for the next batch. + var curCachedSize = serializedTables.map(tableOperator.getDataLen).sum + var curCachedRows = serializedTables.map(tableOperator.getNumRows(_).toLong).sum + while (iter.hasNext && curCachedSize < realBatchSize && curCachedRows < Int.MaxValue) { + closeOnExcept(iter.next()) { batch => + inputBatchesMetric += 1 + inputRowsMetric += batch.numRows() + if (batch.numRows > 0) { + val tableCol = batch.column(0).asInstanceOf[T] + serializedTables.enqueue(tableCol) + curCachedSize += tableOperator.getDataLen(tableCol) + curCachedRows += tableOperator.getNumRows(tableCol) + } else { + batch.close() + } + } + } + serializedTables.nonEmpty + } + + private def collectTablesForNextBatch(targetSize: Long): Array[T] = { + var curSize = 0L + var curRows = 0L + val taken = serializedTables.takeWhile { tableCol => + curSize += tableOperator.getDataLen(tableCol) + curRows += tableOperator.getNumRows(tableCol) + curSize <= targetSize && curRows < Int.MaxValue + } + if (taken.isEmpty) { + // The first batch size is bigger than targetSize, always take it + Array(serializedTables.head) + } else { + taken.toArray + } + } + + private val reduceBatchSizeByHalf: AutoCloseableTargetSize => Seq[AutoCloseableTargetSize] = + batchSize => { + val halfSize = splitTargetSizeInHalfGpu(batchSize) + assert(halfSize.length == 1) + // Remember the size for the following caching and collecting. + logDebug(s"Update target batch size from $realBatchSize to ${halfSize.head.targetSize}") + realBatchSize = halfSize.head.targetSize + halfSize + } + + private def buildNextBatch(): ColumnarBatch = { + val closeableBatchSize = AutoCloseableTargetSize(realBatchSize, 1) + val iter = withRetry(closeableBatchSize, reduceBatchSizeByHalf) { attemptSize => + val (concatRet, numTables) = withResource(new MetricRange(opTimeMetric)) { _ => + // Retry steps: + // 1) Collect tables from cache for the next batch according to the target size. + // 2) Concatenate the collected tables + // 3) Move the concatenated result to GPU + // We have to re-collect the tables and re-concatenate them, because the + // HostConcatResult can not be split into smaller pieces. 
+ val curTables = collectTablesForNextBatch(attemptSize.targetSize) + val concatHostBatch = withResource(new MetricRange(concatTimeMetric)) { _ => + tableOperator.concatOnHost(curTables) + } + (concatHostBatch, curTables.length) + } + withResourceIfAllowed(concatRet) { _ => + // Begin to use GPU + GpuSemaphore.acquireIfNecessary(TaskContext.get()) + withResource(new MetricRange(opTimeMetric)) { _ => + (tableOperator.toGpu(concatRet, dataTypes), numTables) + } + } + } + // Expect only one batch + val (batch, numTables) = iter.next() + closeOnExcept(batch) { _ => + assert(iter.isEmpty) + // Now it is ok to remove the first numTables table from cache. + (0 until numTables).safeMap(_ => serializedTables.dequeue()).safeClose() + batch + } + } + + def prefetchHeadOnHost(): this.type = { + if (serializedTables.isEmpty) { + pullNextBatch() + } + this + } + + def asIterator: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { + override def hasNext: Boolean = pullNextBatch() + override def next(): ColumnarBatch = { + if (!hasNext) { + throw new NoSuchElementException("No more host batches to read") + } + val batch = buildNextBatch() + outputBatchesMetric += 1 + outputRowsMetric += batch.numRows() + batch + } + } +} + +class GpuShuffleCoalesceReader( + iter: Iterator[ColumnarBatch], + targetBatchSize: Long, + dataTypes: Array[DataType], + metricsMap: Map[String, GpuMetric]) + extends GpuShuffleCoalesceReaderBase[SerializedTableColumn, HostConcatResult](iter, + targetBatchSize, dataTypes, metricsMap) { + + override protected val tableOperator = new CudfTableOperator +} + /** * Iterator that coalesces columnar batches that are expected to only contain - * [[SerializedTableColumn]]. The serialized tables within are collected up + * serialized tables from shuffle. 
The serialized tables within are collected up * to the target batch size and then concatenated on the host before handing * them to the caller on `.next()` */ -class HostShuffleCoalesceIterator( +abstract class HostCoalesceIteratorBase[T <: AutoCloseable: ClassTag, C]( iter: Iterator[ColumnarBatch], targetBatchByteSize: Long, - metricsMap: Map[String, GpuMetric]) - extends Iterator[HostConcatResult] with AutoCloseable { + metricsMap: Map[String, GpuMetric]) extends Iterator[C] with AutoCloseable { private[this] val concatTimeMetric = metricsMap(GpuMetric.CONCAT_TIME) private[this] val inputBatchesMetric = metricsMap(GpuMetric.NUM_INPUT_BATCHES) private[this] val inputRowsMetric = metricsMap(GpuMetric.NUM_INPUT_ROWS) - private[this] val serializedTables = new util.ArrayDeque[SerializedTableColumn] + private[this] val serializedTables = new util.ArrayDeque[T] private[this] var numTablesInBatch: Int = 0 private[this] var numRowsInBatch: Int = 0 private[this] var batchByteSize: Long = 0L // Don't install the callback if in a unit test Option(TaskContext.get()).foreach { tc => - onTaskCompletion(tc) { - close() - } + onTaskCompletion(tc)(close()) } + protected def tableOperator: TableOperator[T, C] + override def close(): Unit = { serializedTables.forEach(_.close()) serializedTables.clear() } - def concatenateTablesInHost(): HostConcatResult = { + private def concatenateTablesInHost(): C = { val result = withResource(new MetricRange(concatTimeMetric)) { _ => - val firstHeader = serializedTables.peekFirst().header - if (firstHeader.getNumColumns == 0) { - (0 until numTablesInBatch).foreach(_ => serializedTables.removeFirst()) - cudf_utils.HostConcatResultUtil.rowsOnlyHostConcatResult(numRowsInBatch) - } else { - val headers = new Array[SerializedTableHeader](numTablesInBatch) - withResource(new Array[HostMemoryBuffer](numTablesInBatch)) { buffers => - headers.indices.foreach { i => - val serializedTable = serializedTables.removeFirst() - headers(i) = serializedTable.header - buffers(i) = serializedTable.hostBuffer - } - JCudfSerialization.concatToHostBuffer(headers, buffers) - } + withResource(new Array[T](numTablesInBatch)) { tables => + tables.indices.foreach(i => tables(i) = serializedTables.removeFirst()) + tableOperator.concatOnHost(tables) } } // update the stats for the next batch in progress numTablesInBatch = serializedTables.size - batchByteSize = 0 numRowsInBatch = 0 if (numTablesInBatch > 0) { require(numTablesInBatch == 1, "should only track at most one buffer that is not in a batch") - val header = serializedTables.peekFirst().header - batchByteSize = header.getDataLen - numRowsInBatch = header.getNumRows + val firstTable = serializedTables.peekFirst() + batchByteSize = tableOperator.getDataLen(firstTable) + numRowsInBatch = tableOperator.getNumRows(firstTable) } - result } @@ -150,14 +417,14 @@ class HostShuffleCoalesceIterator( // don't bother tracking empty tables if (batch.numRows > 0) { inputRowsMetric += batch.numRows() - val tableColumn = batch.column(0).asInstanceOf[SerializedTableColumn] - batchCanGrow = canAddToBatch(tableColumn.header) + val tableColumn = batch.column(0).asInstanceOf[T] + batchCanGrow = canAddToBatch(tableColumn) serializedTables.addLast(tableColumn) // always add the first table to the batch even if its beyond the target limits if (batchCanGrow || numTablesInBatch == 0) { numTablesInBatch += 1 - numRowsInBatch += tableColumn.header.getNumRows - batchByteSize += tableColumn.header.getDataLen + numRowsInBatch += tableOperator.getNumRows(tableColumn) + 
batchByteSize += tableOperator.getDataLen(tableColumn) } } else { batch.close() @@ -172,34 +439,40 @@ class HostShuffleCoalesceIterator( numTablesInBatch > 0 } - override def next(): HostConcatResult = { + override def next(): C = { if (!hasNext()) { throw new NoSuchElementException("No more host batches to concatenate") } concatenateTablesInHost() } - private def canAddToBatch(nextTable: SerializedTableHeader): Boolean = { - if (batchByteSize + nextTable.getDataLen > targetBatchByteSize) { + private def canAddToBatch(nextTable: T): Boolean = { + if (batchByteSize + tableOperator.getDataLen(nextTable) > targetBatchByteSize) { return false } - if (numRowsInBatch.toLong + nextTable.getNumRows > Integer.MAX_VALUE) { + if (numRowsInBatch.toLong + tableOperator.getNumRows(nextTable) > Integer.MAX_VALUE) { return false } true } } +class HostShuffleCoalesceIterator( + iter: Iterator[ColumnarBatch], + targetBatchByteSize: Long, + metricsMap: Map[String, GpuMetric]) + extends HostCoalesceIteratorBase[SerializedTableColumn, HostConcatResult](iter, + targetBatchByteSize, metricsMap) { + override protected def tableOperator = new CudfTableOperator +} + /** - * Iterator that coalesces columnar batches that are expected to only contain - * [[SerializedTableColumn]]. The serialized tables within are collected up - * to the target batch size and then concatenated on the host before the data - * is transferred to the GPU. + * Iterator that expects only the coalesced host buffers as the input, and transfers + * the host buffers to GPU. */ -class GpuShuffleCoalesceIterator(iter: Iterator[HostConcatResult], - dataTypes: Array[DataType], - metricsMap: Map[String, GpuMetric]) - extends Iterator[ColumnarBatch] { +class GpuShuffleCoalesceIterator(iter: Iterator[AnyRef], + dataTypes: Array[DataType], + metricsMap: Map[String, GpuMetric]) extends Iterator[ColumnarBatch] { private[this] val opTimeMetric = metricsMap(GpuMetric.OP_TIME) private[this] val outputBatchesMetric = metricsMap(GpuMetric.NUM_OUTPUT_BATCHES) private[this] val outputRowsMetric = metricsMap(GpuMetric.NUM_OUTPUT_ROWS) @@ -218,15 +491,14 @@ class GpuShuffleCoalesceIterator(iter: Iterator[HostConcatResult], iter.next() } - withResource(hostConcatResult) { _ => + withResourceIfAllowed(hostConcatResult) { _ => // We acquire the GPU regardless of whether `hostConcatResult` // is an empty batch or not, because the downstream tasks expect // the `GpuShuffleCoalesceIterator` to acquire the semaphore and may // generate GPU data from batches that are empty. 
GpuSemaphore.acquireIfNecessary(TaskContext.get()) - withResource(new MetricRange(opTimeMetric)) { _ => - val batch = cudf_utils.HostConcatResultUtil.getColumnarBatch(hostConcatResult, dataTypes) + val batch = GpuShuffleCoalesceUtils.toGpuIfAllowed(hostConcatResult, dataTypes) outputBatchesMetric += 1 outputRowsMetric += batch.numRows() batch diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala index b4841046acc..de08f715cd5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala @@ -19,7 +19,6 @@ package com.nvidia.spark.rapids import scala.collection.mutable import ai.rapids.cudf.{NvtxColor, NvtxRange} -import ai.rapids.cudf.JCudfSerialization.HostConcatResult import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit import com.nvidia.spark.rapids.shims.{GpuHashPartitioning, ShimBinaryExecNode} @@ -72,6 +71,7 @@ class GpuShuffledHashJoinMeta( val Seq(left, right) = childPlans.map(_.convertIfNeeded()) val useSizedJoin = GpuShuffledSizedHashJoinExec.useSizedJoin(conf, join.joinType, join.leftKeys, join.rightKeys) + val readOpt = CoalesceReadOption(conf) val joinExec = join.joinType match { case LeftOuter | RightOuter if useSizedJoin => GpuShuffledAsymmetricHashJoinExec( @@ -83,6 +83,7 @@ class GpuShuffledHashJoinMeta( right, conf.isGPUShuffle, conf.gpuTargetBatchSizeBytes, + readOpt, isSkewJoin = false)( join.leftKeys, join.rightKeys, @@ -97,6 +98,7 @@ class GpuShuffledHashJoinMeta( right, conf.isGPUShuffle, conf.gpuTargetBatchSizeBytes, + readOpt, isSkewJoin = false)( join.leftKeys, join.rightKeys) @@ -285,22 +287,20 @@ object GpuShuffledHashJoinExec extends Logging { val buildTypes = buildOutput.map(_.dataType).toArray closeOnExcept(new CloseableBufferedIterator(buildIter)) { bufBuildIter => val startTime = System.nanoTime() + var isBuildSerialized = false // Batches type detection - val isBuildSerialized = bufBuildIter.hasNext && isBatchSerialized(bufBuildIter.head) - - // Let batches coalesce for size overflow check - val coalesceBuiltIter = if (isBuildSerialized) { - new HostShuffleCoalesceIterator(bufBuildIter, targetSize, coalesceMetrics) - } else { // Batches on GPU have already coalesced to the target size by the given goal. - bufBuildIter - } - + val coalesceBuiltIter = GpuShuffleCoalesceUtils.getHostShuffleCoalesceIterator( + bufBuildIter, targetSize, coalesceMetrics).map { iter => + isBuildSerialized = true + iter + }.getOrElse(bufBuildIter) if (coalesceBuiltIter.hasNext) { val firstBuildBatch = coalesceBuiltIter.next() // Batches have coalesced to the target size, so size will overflow if there are // more than one batch, or the first batch size already exceeds the target. val sizeOverflow = closeOnExcept(firstBuildBatch) { _ => - coalesceBuiltIter.hasNext || getBatchSize(firstBuildBatch) > targetSize + coalesceBuiltIter.hasNext || + GpuShuffleCoalesceUtils.getCoalescedBufferSize(firstBuildBatch) > targetSize } val needSingleBuildBatch = !subPartConf.getOrElse(sizeOverflow) if (needSingleBuildBatch && isBuildSerialized && !sizeOverflow) { @@ -309,11 +309,11 @@ object GpuShuffledHashJoinExec extends Logging { // It can be optimized for grabbing the GPU semaphore when there is only a single // serialized host batch and the sub-partitioning is not activated. 
val (singleBuildCb, bufferedStreamIter) = getBuildBatchOptimizedAndClose( - firstBuildBatch.asInstanceOf[HostConcatResult], streamIter, buildTypes, - buildGoal, buildTime) + firstBuildBatch, streamIter, buildTypes, buildGoal, buildTime) logDebug("In the optimized case for grabbing the GPU semaphore, return " + - s"a single batch (size: ${getBatchSize(singleBuildCb)}) for the build side " + - s"with $buildGoal goal.") + s"a single batch (size: " + + s"${GpuShuffleCoalesceUtils.getCoalescedBufferSize(singleBuildCb)}) for the " + + s"build side with $buildGoal goal.") (Left(singleBuildCb), bufferedStreamIter) } else { // Other cases without optimization @@ -321,8 +321,7 @@ object GpuShuffledHashJoinExec extends Logging { coalesceBuiltIter val gpuBuildIter = if (isBuildSerialized) { // batches on host, move them to GPU - new GpuShuffleCoalesceIterator(safeIter.asInstanceOf[Iterator[HostConcatResult]], - buildTypes, coalesceMetrics) + new GpuShuffleCoalesceIterator(safeIter, buildTypes, coalesceMetrics) } else { // batches already on GPU safeIter.asInstanceOf[Iterator[ColumnarBatch]] } @@ -334,8 +333,9 @@ object GpuShuffledHashJoinExec extends Logging { }.getOrElse { if (needSingleBuildBatch) { val oneCB = getAsSingleBatch(gpuBuildIter, buildOutput, buildGoal, coalesceMetrics) - logDebug(s"Return a single batch (size: ${getBatchSize(oneCB)}) for the " + - s"build side with $buildGoal goal.") + logDebug(s"Return a single batch (size: " + + s"${GpuShuffleCoalesceUtils.getCoalescedBufferSize(oneCB)}) for the build " + + s"side with $buildGoal goal.") Left(oneCB) } else { logDebug("Return multiple batches as the build side data for the following " + @@ -411,16 +411,8 @@ object GpuShuffledHashJoinExec extends Logging { } } - /** Only accepts a HostConcatResult or a ColumnarBatch as input */ - private def getBatchSize(maybeBatch: AnyRef): Long = maybeBatch match { - case batch: ColumnarBatch => GpuColumnVector.getTotalDeviceMemoryUsed(batch) - case hostBatch: HostConcatResult => hostBatch.getTableHeader().getDataLen() - case _ => throw new IllegalStateException(s"Expect a HostConcatResult or a " + - s"ColumnarBatch, but got a ${maybeBatch.getClass.getSimpleName}") - } - private def getBuildBatchOptimizedAndClose( - hostConcatResult: HostConcatResult, + hostConcatResult: AutoCloseable, streamIter: Iterator[ColumnarBatch], buildDataTypes: Array[DataType], buildGoal: CoalesceSizeGoal, @@ -441,8 +433,7 @@ object GpuShuffledHashJoinExec extends Logging { } // Bring the build batch to the GPU now val buildBatch = buildTime.ns { - val cb = - cudf_utils.HostConcatResultUtil.getColumnarBatch(hostConcatResult, buildDataTypes) + val cb = GpuShuffleCoalesceUtils.toGpuIfAllowed(hostConcatResult, buildDataTypes) getFilterFunc(buildGoal).map(filterAndClose => filterAndClose(cb)).getOrElse(cb) } (buildBatch, bufStreamIter) @@ -462,9 +453,5 @@ object GpuShuffledHashJoinExec extends Logging { "single build batch") ConcatAndConsumeAll.getSingleBatchWithVerification(singleBatchIter, inputAttrs) } - - def isBatchSerialized(batch: ColumnarBatch): Boolean = { - batch.numCols() == 1 && batch.column(0).isInstanceOf[SerializedTableColumn] - } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala index 4d06bdf0553..f37b5810ae6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala +++ 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledSizedHashJoinExec.scala @@ -223,6 +223,8 @@ object GpuShuffledSizedHashJoinExec { * grabbing the GPU semaphore. */ trait HostHostJoinSizer extends JoinSizer[SpillableHostConcatResult] { + def readOption: CoalesceReadOption + override def setupForProbe( iter: Iterator[ColumnarBatch]): Iterator[SpillableHostConcatResult] = { new SpillableHostConcatResultFromColumnarBatchIterator(iter) @@ -235,24 +237,21 @@ object GpuShuffledSizedHashJoinExec { gpuBatchSizeBytes: Long, metrics: Map[String, GpuMetric]): Iterator[ColumnarBatch] = { val concatMetrics = getConcatMetrics(metrics) - val bufferedCoalesceIter = new CloseableBufferedIterator( - new HostShuffleCoalesceIterator( - new HostQueueBatchIterator(queue, remainingIter), - gpuBatchSizeBytes, - concatMetrics)) - withResource(new NvtxRange("fetch first batch", NvtxColor.YELLOW)) { _ => - // Force a coalesce of the first batch before we grab the GPU semaphore - bufferedCoalesceIter.headOption - } - new GpuShuffleCoalesceIterator(bufferedCoalesceIter, batchTypes, concatMetrics) + GpuShuffleCoalesceUtils.getGpuShuffleCoalesceIterator( + new HostQueueBatchIterator(queue, remainingIter), + gpuBatchSizeBytes, + batchTypes, + readOption, + concatMetrics, + prefetchFirstBatch = true) } override def getProbeBatchRowCount(batch: SpillableHostConcatResult): Long = { - batch.header.getNumRows + batch.getNumRows } override def getProbeBatchDataSize(batch: SpillableHostConcatResult): Long = { - batch.header.getDataLen + batch.getDataLen } } @@ -265,6 +264,8 @@ object GpuShuffledSizedHashJoinExec { * See https://github.com/NVIDIA/spark-rapids/issues/11322. */ trait HostHostUnspillableJoinSizer extends JoinSizer[ColumnarBatch] { + def readOption: CoalesceReadOption + override def setupForProbe( iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = iter @@ -275,31 +276,25 @@ object GpuShuffledSizedHashJoinExec { gpuBatchSizeBytes: Long, metrics: Map[String, GpuMetric]): Iterator[ColumnarBatch] = { val concatMetrics = getConcatMetrics(metrics) - val bufferedCoalesceIter = new CloseableBufferedIterator( - new HostShuffleCoalesceIterator( - queue.iterator ++ remainingIter, - gpuBatchSizeBytes, - concatMetrics)) - withResource(new NvtxRange("fetch first batch", NvtxColor.YELLOW)) { _ => - // Force a coalesce of the first batch before we grab the GPU semaphore - bufferedCoalesceIter.headOption - } - new GpuShuffleCoalesceIterator(bufferedCoalesceIter, batchTypes, concatMetrics) + GpuShuffleCoalesceUtils.getGpuShuffleCoalesceIterator( + queue.iterator ++ remainingIter, + gpuBatchSizeBytes, + batchTypes, + readOption, + concatMetrics, + prefetchFirstBatch = true) } override def getProbeBatchRowCount(batch: ColumnarBatch): Long = batch.numRows() override def getProbeBatchDataSize(batch: ColumnarBatch): Long = { - SerializedTableColumn.getMemoryUsed(batch) + GpuShuffleCoalesceUtils.getSerializedBufferSize(batch) } } /** * Join sizer to use when at least one side of the join is coming from another GPU exec node * such that the GPU semaphore is already held. Caches input batches on the GPU. - * - * @param startWithLeftSide whether to prefer fetching from the left or right side first - * when probing for table sizes. 
*/ trait SpillableColumnarBatchJoinSizer extends JoinSizer[SpillableColumnarBatch] { override def setupForProbe(iter: Iterator[ColumnarBatch]): Iterator[SpillableColumnarBatch] = { @@ -377,8 +372,10 @@ abstract class GpuShuffledSizedHashJoinExec[HOST_BATCH_TYPE <: AutoCloseable] ex def isSkewJoin: Boolean def cpuLeftKeys: Seq[Expression] def cpuRightKeys: Seq[Expression] + def readOption: CoalesceReadOption - protected def createHostHostSizer(): JoinSizer[HOST_BATCH_TYPE] + protected def createHostHostSizer( + readOption: CoalesceReadOption): JoinSizer[HOST_BATCH_TYPE] protected def createSpillableColumnarBatchSizer( startWithLeftSide: Boolean): JoinSizer[SpillableColumnarBatch] @@ -425,20 +422,21 @@ abstract class GpuShuffledSizedHashJoinExec[HOST_BATCH_TYPE <: AutoCloseable] ex val localCondition = condition val localGpuBatchSizeBytes = gpuBatchSizeBytes val localMetrics = allMetrics.withDefaultValue(NoopMetric) + val localReadOption = readOption left.executeColumnar().zipPartitions(right.executeColumnar()) { case (leftIter, rightIter) => val joinInfo = (isLeftHost, isRightHost) match { case (true, true) => getHostHostJoinInfo(localJoinType, localLeftKeys, leftOutput, leftIter, - localRightKeys, rightOutput, rightIter, - localCondition, localGpuBatchSizeBytes, localMetrics) + localRightKeys, rightOutput, rightIter, localCondition, + localGpuBatchSizeBytes, localReadOption, localMetrics) case (true, false) => getHostGpuJoinInfo(localJoinType, localLeftKeys, leftOutput, leftIter, - localRightKeys, rightOutput, rightIter, - localCondition, localGpuBatchSizeBytes, localMetrics) + localRightKeys, rightOutput, rightIter, localCondition, + localGpuBatchSizeBytes, localReadOption, localMetrics) case (false, true) => getGpuHostJoinInfo(localJoinType, localLeftKeys, leftOutput, leftIter, - localRightKeys, rightOutput, rightIter, - localCondition, localGpuBatchSizeBytes, localMetrics) + localRightKeys, rightOutput, rightIter, localCondition, + localGpuBatchSizeBytes, localReadOption, localMetrics) case (false, false) => getGpuGpuJoinInfo(localJoinType, localLeftKeys, leftOutput, leftIter, localRightKeys, rightOutput, rightIter, @@ -539,8 +537,9 @@ abstract class GpuShuffledSizedHashJoinExec[HOST_BATCH_TYPE <: AutoCloseable] ex rightIter: Iterator[ColumnarBatch], condition: Option[Expression], gpuBatchSizeBytes: Long, + readOption: CoalesceReadOption, metrics: Map[String, GpuMetric]): JoinInfo = { - val sizer = createHostHostSizer() + val sizer = createHostHostSizer(readOption) sizer.getJoinInfo(joinType, leftKeys, leftOutput, leftIter, rightKeys, rightOutput, rightIter, condition, gpuBatchSizeBytes, metrics) } @@ -559,12 +558,15 @@ abstract class GpuShuffledSizedHashJoinExec[HOST_BATCH_TYPE <: AutoCloseable] ex rightIter: Iterator[ColumnarBatch], condition: Option[Expression], gpuBatchSizeBytes: Long, + readOption: CoalesceReadOption, metrics: Map[String, GpuMetric]): JoinInfo = { val sizer = createSpillableColumnarBatchSizer(startWithLeftSide = true) val concatMetrics = getConcatMetrics(metrics) - val leftIter = new GpuShuffleCoalesceIterator( - new HostShuffleCoalesceIterator(rawLeftIter, gpuBatchSizeBytes, concatMetrics), + val leftIter = GpuShuffleCoalesceUtils.getGpuShuffleCoalesceIterator( + rawLeftIter, + gpuBatchSizeBytes, leftOutput.map(_.dataType).toArray, + readOption, concatMetrics) sizer.getJoinInfo(joinType, leftKeys, leftOutput, leftIter, rightKeys, rightOutput, rightIter, condition, gpuBatchSizeBytes, metrics) @@ -584,12 +586,15 @@ abstract class 
GpuShuffledSizedHashJoinExec[HOST_BATCH_TYPE <: AutoCloseable] ex rawRightIter: Iterator[ColumnarBatch], condition: Option[Expression], gpuBatchSizeBytes: Long, + readOption: CoalesceReadOption, metrics: Map[String, GpuMetric]): JoinInfo = { val sizer = createSpillableColumnarBatchSizer(startWithLeftSide = false) val concatMetrics = getConcatMetrics(metrics) - val rightIter = new GpuShuffleCoalesceIterator( - new HostShuffleCoalesceIterator(rawRightIter, gpuBatchSizeBytes, concatMetrics), + val rightIter = GpuShuffleCoalesceUtils.getGpuShuffleCoalesceIterator( + rawRightIter, + gpuBatchSizeBytes, rightOutput.map(_.dataType).toArray, + readOption, concatMetrics) sizer.getJoinInfo(joinType, leftKeys, leftOutput, leftIter, rightKeys, rightOutput, rightIter, condition, gpuBatchSizeBytes, metrics) @@ -728,8 +733,9 @@ object GpuShuffledSymmetricHashJoinExec { } } - class HostHostSymmetricJoinSizer extends SymmetricJoinSizer[SpillableHostConcatResult] - with HostHostJoinSizer { + class HostHostSymmetricJoinSizer(override val readOption: CoalesceReadOption) + extends SymmetricJoinSizer[SpillableHostConcatResult] with HostHostJoinSizer { + override val startWithLeftSide: Boolean = true } @@ -762,6 +768,7 @@ case class GpuShuffledSymmetricHashJoinExec( override val right: SparkPlan, override val isGpuShuffle: Boolean, override val gpuBatchSizeBytes: Long, + override val readOption: CoalesceReadOption, override val isSkewJoin: Boolean)( override val cpuLeftKeys: Seq[Expression], override val cpuRightKeys: Seq[Expression]) @@ -771,8 +778,9 @@ case class GpuShuffledSymmetricHashJoinExec( override def otherCopyArgs: Seq[AnyRef] = Seq(cpuLeftKeys, cpuRightKeys) - override protected def createHostHostSizer(): JoinSizer[SpillableHostConcatResult] = { - new HostHostSymmetricJoinSizer() + override protected def createHostHostSizer( + readOption: CoalesceReadOption): JoinSizer[SpillableHostConcatResult] = { + new HostHostSymmetricJoinSizer(readOption) } override protected def createSpillableColumnarBatchSizer( @@ -1022,7 +1030,9 @@ object GpuShuffledAsymmetricHashJoinExec { } } - class HostHostAsymmetricJoinSizer(override val magnificationThreshold: Int) + class HostHostAsymmetricJoinSizer( + override val magnificationThreshold: Int, + override val readOption: CoalesceReadOption) extends AsymmetricJoinSizer[ColumnarBatch] with HostHostUnspillableJoinSizer { } @@ -1055,6 +1065,7 @@ case class GpuShuffledAsymmetricHashJoinExec( override val right: SparkPlan, override val isGpuShuffle: Boolean, override val gpuBatchSizeBytes: Long, + override val readOption: CoalesceReadOption, override val isSkewJoin: Boolean)( override val cpuLeftKeys: Seq[Expression], override val cpuRightKeys: Seq[Expression], @@ -1064,8 +1075,9 @@ case class GpuShuffledAsymmetricHashJoinExec( override def otherCopyArgs: Seq[AnyRef] = Seq(cpuLeftKeys, cpuRightKeys, magnificationThreshold) - override protected def createHostHostSizer(): JoinSizer[ColumnarBatch] = { - new HostHostAsymmetricJoinSizer(magnificationThreshold) + override protected def createHostHostSizer( + readOption: CoalesceReadOption): JoinSizer[ColumnarBatch] = { + new HostHostAsymmetricJoinSizer(magnificationThreshold, readOption) } override protected def createSpillableColumnarBatchSizer( @@ -1077,19 +1089,14 @@ case class GpuShuffledAsymmetricHashJoinExec( /** * A spillable form of a HostConcatResult. Takes ownership of the specified host buffer. 
*/ -class SpillableHostConcatResult( - val header: SerializedTableHeader, - hmb: HostMemoryBuffer) extends AutoCloseable { - private var buffer = { - SpillableHostBuffer(hmb, hmb.getLength, SpillPriorities.ACTIVE_BATCHING_PRIORITY) - } +sealed trait SpillableHostConcatResult extends AutoCloseable { + def hmb: HostMemoryBuffer + def toBatch: ColumnarBatch + def getNumRows: Long + def getDataLen: Long - def getHostMemoryBufferAndClose(): HostMemoryBuffer = { - val hostBuffer = buffer.getHostBuffer() - closeOnExcept(hostBuffer) { _ => - close() - } - hostBuffer + protected var buffer = { + SpillableHostBuffer(hmb, hmb.getLength, SpillPriorities.ACTIVE_BATCHING_PRIORITY) } override def close(): Unit = { @@ -1098,6 +1105,35 @@ class SpillableHostConcatResult( } } +class CudfSpillableHostConcatResult( + header: SerializedTableHeader, + val hmb: HostMemoryBuffer) extends SpillableHostConcatResult { + + override def toBatch: ColumnarBatch = { + closeOnExcept(buffer.getHostBuffer()) { hostBuf => + SerializedTableColumn.from(header, hostBuf) + } + } + + override def getNumRows: Long = header.getNumRows + + override def getDataLen: Long = header.getDataLen +} + +object SpillableHostConcatResult { + def from(batch: ColumnarBatch): SpillableHostConcatResult = { + require(batch.numCols() > 0, "Batch must have at least 1 column") + batch.column(0) match { + case col: SerializedTableColumn => + val buffer = col.hostBuffer + buffer.incRefCount() + new CudfSpillableHostConcatResult(col.header, buffer) + case c => + throw new IllegalStateException(s"Expected SerializedTableColumn, got ${c.getClass}") + } + } +} + /** * Converts an iterator of shuffle batches in host memory into an iterator of spillable * host memory batches. @@ -1107,17 +1143,7 @@ class SpillableHostConcatResultFromColumnarBatchIterator( override def hasNext: Boolean = iter.hasNext override def next(): SpillableHostConcatResult = { - withResource(iter.next()) { batch => - require(batch.numCols() > 0, "Batch must have at least 1 column") - batch.column(0) match { - case col: SerializedTableColumn => - val buffer = col.hostBuffer - buffer.incRefCount() - new SpillableHostConcatResult(col.header, buffer) - case c => - throw new IllegalStateException(s"Expected SerializedTableColumn, got ${c.getClass}") - } - } + withResource(iter.next())(SpillableHostConcatResult.from) } } @@ -1137,10 +1163,7 @@ class HostQueueBatchIterator( override def next(): ColumnarBatch = { if (spillableQueue.nonEmpty) { - val shcr = spillableQueue.dequeue() - closeOnExcept(shcr.getHostMemoryBufferAndClose()) { hostBuffer => - SerializedTableColumn.from(shcr.header, hostBuffer) - } + withResource(spillableQueue.dequeue())(_.toBatch) } else { batchIter.next() } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinMeta.scala index b7a9fcb9020..7d7adfc5097 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinMeta.scala @@ -84,6 +84,7 @@ class GpuSortMergeJoinMeta( val Seq(left, right) = childPlans.map(_.convertIfNeeded()) val useSizedJoin = GpuShuffledSizedHashJoinExec.useSizedJoin(conf, join.joinType, join.leftKeys, join.rightKeys) + val readOpt = CoalesceReadOption(conf) val joinExec = join.joinType match { case LeftOuter | RightOuter if useSizedJoin => GpuShuffledAsymmetricHashJoinExec( @@ -95,6 +96,7 @@ class GpuSortMergeJoinMeta( right, conf.isGPUShuffle, 
conf.gpuTargetBatchSizeBytes, + readOpt, join.isSkewJoin)( join.leftKeys, join.rightKeys, @@ -109,6 +111,7 @@ class GpuSortMergeJoinMeta( right, conf.isGPUShuffle, conf.gpuTargetBatchSizeBytes, + readOpt, join.isSkewJoin)( join.leftKeys, join.rightKeys) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 50dc457268c..85de28ade45 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -1927,6 +1927,13 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. .integerConf .createWithDefault(20) + val SHUFFLE_SPLITRETRY_READ = conf("spark.rapids.shuffle.splitRetryRead.enabled") + .doc("When set to true, use the resizable shuffle reader, which halves the " + + "target batch size and retries when an OOM occurs during a coalescing shuffle read.") + .internal() + .booleanConf + .createWithDefault(true) + // ALLUXIO CONFIGS val ALLUXIO_MASTER = conf("spark.rapids.alluxio.master") .doc("The Alluxio master hostname. If not set, read Alluxio master URL from " + @@ -3217,6 +3224,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val caseWhenFuseEnabled: Boolean = get(CASE_WHEN_FUSE) + lazy val shuffleSplitRetryReadEnabled: Boolean = get(SHUFFLE_SPLITRETRY_READ) + private val optimizerDefaults = Map( // this is not accurate because CPU projections do have a cost due to appending values // to each row that is produced, but this needs to be a really small number because diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RmmRapidsRetryIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RmmRapidsRetryIterator.scala index d86aa596325..986790f4410 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RmmRapidsRetryIterator.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RmmRapidsRetryIterator.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -631,7 +631,7 @@ object RmmRapidsRetryIterator extends Logging { clearInjectedOOMIfNeeded() // make sure we add any prior exceptions to this one as causes - if (lastException != null) { + if (lastException != null && lastException != ex) { ex.addSuppressed(lastException) } lastException = ex diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceReaderRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceReaderRetrySuite.scala new file mode 100644 index 00000000000..4b8f8837b32 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceReaderRetrySuite.scala @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream} + +import ai.rapids.cudf.{HostColumnVector, HostMemoryBuffer, JCudfSerialization} +import ai.rapids.cudf.JCudfSerialization.SerializedTableHeader +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.jni.RmmSpark + +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.vectorized.ColumnarBatch + +class GpuShuffleCoalesceReaderRetrySuite extends RmmSparkRetrySuiteBase { + + private def serializedOneIntColumnBatch(ints: Int*): ColumnarBatch = { + val outStream = new ByteArrayOutputStream(100) + withResource(HostColumnVector.fromInts(ints: _*)) { col => + JCudfSerialization.writeToStream(Array(col), outStream, 0, col.getRowCount) + } + val inStream = new DataInputStream(new ByteArrayInputStream(outStream.toByteArray)) + val header = new SerializedTableHeader(inStream) + closeOnExcept(HostMemoryBuffer.allocate(header.getDataLen, false)) { hostBuffer => + JCudfSerialization.readTableIntoBuffer(inStream, header, hostBuffer) + SerializedTableColumn.from(header, hostBuffer) + } + } + + private def serializedBatches = Seq( + serializedOneIntColumnBatch(1), + serializedOneIntColumnBatch(3), + serializedOneIntColumnBatch(2), + serializedOneIntColumnBatch(5), + serializedOneIntColumnBatch(4)) + + test("GpuShuffleCoalesceReader split-retry") { + val iter = closeOnExcept(serializedBatches) { serBatches => + val reader = new GpuShuffleCoalesceReader( + Iterator(serBatches: _*), + targetBatchSize = 390, // each is 64 due to padding, then total is 320 (=64x5) + dataTypes = Array(IntegerType), + metricsMap = Map.empty.withDefaultValue(NoopMetric)) + RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId) + reader.asIterator + } + withResource(iter.toSeq) { batches => + // 2 batches because of the split-retry + assertResult(expected = 2)(batches.length) + // still 5 rows + assertResult(expected = 5)(batches.map(_.numRows()).sum) + } + } + +}