[apache#808] feat(spark): ensure thread safe and data consistency when spill
zuston committed Jul 20, 2023
1 parent d6b43f0 commit dbedf57
Showing 8 changed files with 475 additions and 183 deletions.
@@ -57,6 +57,12 @@ public class RssSparkConfig {
.defaultValue(true)
.withDescription("indicates row-based shuffle, set false when used in columnar shuffle");

public static final ConfigOption<Boolean> RSS_MEMORY_SPILL_ENABLED = ConfigOptions
.key("rss.client.memory.spill.enabled")
.booleanType()
.defaultValue(false)
.withDescription("Whether to enable the memory spill triggered by Spark's TaskMemoryManager. Default is false.");

public static final String SPARK_RSS_CONFIG_PREFIX = "spark.";

public static final ConfigEntry<Integer> RSS_PARTITION_NUM_PER_RANGE =
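For context, here is a minimal sketch (not part of this commit) of how a job could turn the new switch on, assuming the option key above is set through its spark.-prefixed form (SPARK_RSS_CONFIG_PREFIX + key); the wrapper class name is illustrative only.

import org.apache.spark.SparkConf;

public class EnableMemorySpillExample {
  public static SparkConf buildConf() {
    // Full key = SPARK_RSS_CONFIG_PREFIX ("spark.") + "rss.client.memory.spill.enabled"; default is false.
    return new SparkConf()
        .set("spark.rss.client.memory.spill.enabled", "true");
  }
}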
@@ -23,6 +23,8 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;

@@ -87,8 +89,9 @@ public class WriteBufferManager extends MemoryConsumer {
private long requireMemoryInterval;
private int requireMemoryRetryMax;
private Codec codec;
private Function<AddBlockEvent, CompletableFuture<Long>> spillFunc;
private Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc;
private long sendSizeLimit;
private boolean memorySpillEnabled;
private int memorySpillTimeoutSec;
private boolean isRowBased;

@@ -124,7 +127,7 @@ public WriteBufferManager(
TaskMemoryManager taskMemoryManager,
ShuffleWriteMetrics shuffleWriteMetrics,
RssConf rssConf,
Function<AddBlockEvent, CompletableFuture<Long>> spillFunc) {
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc) {
super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
this.bufferSize = bufferManagerOptions.getBufferSize();
this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
@@ -155,6 +158,7 @@ public WriteBufferManager(
this.spillFunc = spillFunc;
this.sendSizeLimit = rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION);
this.memorySpillTimeoutSec = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
this.memorySpillEnabled = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_ENABLED);
}

/** add serialized columnar data directly when integrate with gluten */
@@ -165,41 +169,69 @@ public List<ShuffleBlockInfo> addPartitionData(int partitionId, byte[] serialize

public List<ShuffleBlockInfo> addPartitionData(
int partitionId, byte[] serializedData, int serializedDataLength, long start) {
List<ShuffleBlockInfo> result = Lists.newArrayList();
List<ShuffleBlockInfo> candidateSendingBlocks = insertIntoBuffer(partitionId, serializedData, serializedDataLength);

// check buffer size > spill threshold
if (usedBytes.get() - inSendListBytes.get() > spillSize) {
candidateSendingBlocks.addAll(clear());
}
writeTime += System.currentTimeMillis() - start;
return candidateSendingBlocks;
}

/**
 * Before inserting a record into its corresponding buffer, the buffer manager checks whether
 * enough buffer memory is available. If not, it requests additional memory from the
 * {@link TaskMemoryManager}. When memory is scarce, that request triggers a spill: the
 * {@link TaskMemoryManager} asks the memory consumers it manages to release memory, one by one.
 *
 * If this buffer manager's own memory request triggers such a spill, the buffer it is currently
 * holding is flushed and dropped, so the record has to be re-inserted into a fresh buffer.
 */
private List<ShuffleBlockInfo> insertIntoBuffer(int partitionId, byte[] serializedData, int serializedDataLength) {
List<ShuffleBlockInfo> sentBlocks = new ArrayList<>();
long required = Math.max(bufferSegmentSize, serializedDataLength);
// Asking the task memory manager for memory for the existing writer buffer may trigger
// this WriteBufferManager's own spill method, which discards the current write buffer.
// So we have to recheck whether the buffer still exists.
boolean hasRequested = false;
if (buffers.containsKey(partitionId)) {
WriterBuffer wb = buffers.get(partitionId);
if (wb.askForMemory(serializedDataLength)) {
requestMemory(Math.max(bufferSegmentSize, serializedDataLength));
requestMemory(required);
hasRequested = true;
}
}

if (buffers.containsKey(partitionId)) {
if (hasRequested) {
usedBytes.addAndGet(required);
}
WriterBuffer wb = buffers.get(partitionId);
wb.addRecord(serializedData, serializedDataLength);
if (wb.getMemoryUsed() > bufferSize) {
result.add(createShuffleBlock(partitionId, wb));
sentBlocks.add(createShuffleBlock(partitionId, wb));
copyTime += wb.getCopyTime();
buffers.remove(partitionId);
LOG.debug(
"Single buffer is full for shuffleId["
+ shuffleId
+ "] partition["
+ partitionId
+ "] with memoryUsed["
+ wb.getMemoryUsed()
+ "], dataLength["
+ wb.getDataLength()
+ "]");
LOG.debug("Single buffer is full for shuffleId[" + shuffleId
+ "] partition[" + partitionId + "] with memoryUsed[" + wb.getMemoryUsed()
+ "], dataLength[" + wb.getDataLength() + "]");
}
} else {
requestMemory(Math.max(bufferSegmentSize, serializedDataLength));
// If hasRequested is true, the previous partition buffer has already been flushed by the
// spill triggered while asking for memory, so there is no need to request the memory again.
if (!hasRequested) {
requestMemory(required);
}
usedBytes.addAndGet(required);

WriterBuffer wb = new WriterBuffer(bufferSegmentSize);
wb.addRecord(serializedData, serializedDataLength);
buffers.put(partitionId, wb);
}

// check buffer size > spill threshold
if (usedBytes.get() - inSendListBytes.get() > spillSize) {
result.addAll(clear());
}
writeTime += System.currentTimeMillis() - start;
return result;
return sentBlocks;
}

public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
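To make the hasRequested handling above concrete, the re-entrant path this hunk guards against looks roughly like the following sequence (a descriptive sketch of the flow already shown in the diff, not additional commit code):

// 1. insertIntoBuffer(p, data, len) finds an existing buffer for partition p and calls requestMemory(required).
// 2. The TaskMemoryManager runs short of memory and calls spill() back on this same WriteBufferManager.
// 3. spill() clears the buffer map, flushing and removing every partition buffer, including p's.
// 4. Back in insertIntoBuffer, buffers.containsKey(p) is now false, so the record goes into a fresh
//    WriterBuffer; since hasRequested is already true, memory is not requested a second time,
//    while usedBytes is still increased by the required amount exactly once.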
@@ -304,7 +336,6 @@ private void requestMemory(long requiredMem) {
if (allocatedBytes.get() - usedBytes.get() < requiredMem) {
requestExecutorMemory(requiredMem);
}
usedBytes.addAndGet(requiredMem);
requireMemoryTime += System.currentTimeMillis() - start;
}

@@ -395,7 +426,31 @@ public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockI

@Override
public long spill(long size, MemoryConsumer trigger) {
return 0L;
// Only spill when the trigger is this instance's own MemoryConsumer; in that case flush the buffers
if (!memorySpillEnabled || trigger != this) {
return 0L;
}

List<CompletableFuture<Long>> futures = spillFunc.apply(clear());
CompletableFuture<Void> allOfFutures =
CompletableFuture.allOf(futures.toArray(new CompletableFuture[futures.size()]));
try {
allOfFutures.get(memorySpillTimeoutSec, TimeUnit.SECONDS);
} catch (TimeoutException timeoutException) {
// Waiting is best-effort: if a timeout occurs, the underlying tasks are not cancelled.
} finally {
long releasedSize = futures.stream().filter(x -> x.isDone()).mapToLong(x -> {
try {
return x.get();
} catch (Exception e) {
return 0;
}
}).sum();
LOG.info("[taskId: {}] Spill triggered by itself, released memory size: {}",
taskId, releasedSize);
return releasedSize;
}
}

@VisibleForTesting
@@ -470,7 +525,7 @@ public void setTaskId(String taskId) {
}

@VisibleForTesting
public void setSpillFunc(Function<AddBlockEvent, CompletableFuture<Long>> spillFunc) {
public void setSpillFunc(Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc) {
this.spillFunc = spillFunc;
}
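As a usage illustration, a test could drive the new spill path through this hook roughly as follows; this is a sketch that assumes an already constructed WriteBufferManager named bufferManager with rss.client.memory.spill.enabled set to true, and the stubbed futures stand in for the real send path.

import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

static long spillForTest(WriteBufferManager bufferManager) throws Exception {
  // Stub the spill function: every cleared block completes immediately with a fixed fake size.
  bufferManager.setSpillFunc(blocks -> blocks.stream()
      .map(block -> CompletableFuture.completedFuture(32L))
      .collect(Collectors.toList()));
  // Trigger the consumer's own spill; the manager flushes its buffers and returns the summed
  // sizes of the futures that completed within the configured timeout.
  return bufferManager.spill(Long.MAX_VALUE, bufferManager);
}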

