diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 78bcc2c177..45d338e39c 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -475,15 +475,6 @@ public ShuffleWriter getWriter( int shuffleId = rssHandle.getShuffleId(); String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber(); - ShuffleHandleInfo shuffleHandleInfo; - if (shuffleManagerRpcServiceEnabled) { - // Get the ShuffleServer list from the Driver based on the shuffleId - shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId); - } else { - shuffleHandleInfo = - new ShuffleHandleInfo( - shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage()); - } ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics(); return new RssShuffleWriter<>( rssHandle.getAppId(), @@ -496,8 +487,7 @@ public ShuffleWriter getWriter( shuffleWriteClient, rssHandle, this::markFailedTask, - context, - shuffleHandleInfo); + context); } else { throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName()); } @@ -806,6 +796,18 @@ private ShuffleManagerClient createShuffleManagerClient(String host, int port) { .createShuffleManagerClient(ClientType.GRPC, host, port); } + public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle rssHandle) { + if (shuffleManagerRpcServiceEnabled) { + // Get the ShuffleServer list from the Driver based on the shuffleId + return getRemoteShuffleHandleInfo(rssHandle.getShuffleId()); + } else { + return new ShuffleHandleInfo( + rssHandle.getShuffleId(), + rssHandle.getPartitionToServers(), + rssHandle.getRemoteStorage()); + } + } + /** * Get the ShuffleServer list from the Driver based on the shuffleId * diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 9e64b2fd5a..37576c1c95 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -188,8 +188,7 @@ public RssShuffleWriter( ShuffleWriteClient shuffleWriteClient, RssShuffleHandle rssHandle, Function taskFailureCallback, - TaskContext context, - ShuffleHandleInfo shuffleHandleInfo) { + TaskContext context) { this( appId, shuffleId, @@ -201,9 +200,10 @@ public RssShuffleWriter( shuffleWriteClient, rssHandle, taskFailureCallback, - shuffleHandleInfo, + shuffleManager.getShuffleHandleInfo(rssHandle), context); BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf); + ShuffleHandleInfo shuffleHandleInfo = shuffleManager.getShuffleHandleInfo(rssHandle); final WriteBufferManager bufferManager = new WriteBufferManager( shuffleId, diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 6d9487ca44..700b7691bf 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -141,7 +141,6 @@ public class RssShuffleManager extends RssShuffleManagerBase { private boolean rssResubmitStage; private boolean taskBlockSendFailureRetryEnabled; - private boolean shuffleManagerRpcServiceEnabled; /** A list of shuffleServer for Write failures */ private Set failuresShuffleServerIds; @@ -514,15 +513,6 @@ public ShuffleWriter getWriter( } else { writeMetrics = context.taskMetrics().shuffleWriteMetrics(); } - ShuffleHandleInfo shuffleHandleInfo; - if (shuffleManagerRpcServiceEnabled) { - // Get the ShuffleServer list from the Driver based on the shuffleId - shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId); - } else { - shuffleHandleInfo = - new ShuffleHandleInfo( - shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage()); - } String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber(); LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId()); return new RssShuffleWriter<>( @@ -536,8 +526,7 @@ public ShuffleWriter getWriter( shuffleWriteClient, rssHandle, this::markFailedTask, - context, - shuffleHandleInfo); + context); } @Override @@ -656,17 +645,7 @@ public ShuffleReader getReaderImpl( RssShuffleHandle rssShuffleHandle = (RssShuffleHandle) handle; final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions(); int shuffleId = rssShuffleHandle.getShuffleId(); - ShuffleHandleInfo shuffleHandleInfo; - if (shuffleManagerRpcServiceEnabled) { - // Get the ShuffleServer list from the Driver based on the shuffleId - shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId); - } else { - shuffleHandleInfo = - new ShuffleHandleInfo( - shuffleId, - rssShuffleHandle.getPartitionToServers(), - rssShuffleHandle.getRemoteStorage()); - } + ShuffleHandleInfo shuffleHandleInfo = getShuffleHandleInfo(rssShuffleHandle); Map> allPartitionToServers = shuffleHandleInfo.getPartitionToServers(); Map> requirePartitionToServers = @@ -1101,6 +1080,18 @@ private ShuffleManagerClient createShuffleManagerClient(String host, int port) { .createShuffleManagerClient(ClientType.GRPC, host, port); } + public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle rssHandle) { + if (shuffleManagerRpcServiceEnabled) { + // Get the ShuffleServer list from the Driver based on the shuffleId + return getRemoteShuffleHandleInfo(rssHandle.getShuffleId()); + } else { + return new ShuffleHandleInfo( + rssHandle.getShuffleId(), + rssHandle.getPartitionToServers(), + rssHandle.getRemoteStorage()); + } + } + /** * Get the ShuffleServer list from the Driver based on the shuffleId * diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 8a22b73ba5..70ae3d8f68 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -95,6 +95,7 @@ public class RssShuffleWriter extends ShuffleWriter { private final String appId; private final int shuffleId; + private final ShuffleHandleInfo shuffleHandleInfo; private WriteBufferManager bufferManager; private String taskId; private final int numMaps; @@ -110,7 +111,8 @@ public class RssShuffleWriter extends ShuffleWriter { private final ShuffleWriteClient shuffleWriteClient; private final Set shuffleServersForData; private final long[] partitionLengths; - private final boolean isMemoryShuffleEnabled; + // Gluten needs this variable + protected final boolean isMemoryShuffleEnabled; private final Function taskFailureCallback; private final Set blockIds = Sets.newConcurrentHashSet(); private TaskContext taskContext; @@ -195,6 +197,7 @@ private RssShuffleWriter( this.isMemoryShuffleEnabled = isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key())); this.taskFailureCallback = taskFailureCallback; + this.shuffleHandleInfo = shuffleHandleInfo; this.taskContext = context; this.sparkConf = sparkConf; this.blockFailSentRetryEnabled = @@ -204,6 +207,7 @@ private RssShuffleWriter( RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.defaultValue()); } + // Gluten needs this constructor public RssShuffleWriter( String appId, int shuffleId, @@ -215,8 +219,7 @@ public RssShuffleWriter( ShuffleWriteClient shuffleWriteClient, RssShuffleHandle rssHandle, Function taskFailureCallback, - TaskContext context, - ShuffleHandleInfo shuffleHandleInfo) { + TaskContext context) { this( appId, shuffleId, @@ -228,7 +231,7 @@ public RssShuffleWriter( shuffleWriteClient, rssHandle, taskFailureCallback, - shuffleHandleInfo, + shuffleManager.getShuffleHandleInfo(rssHandle), context); BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf); final WriteBufferManager bufferManager = @@ -264,7 +267,8 @@ public void write(Iterator> records) { } } - private void writeImpl(Iterator> records) { + // Gluten needs this method. + protected void writeImpl(Iterator> records) { List shuffleBlockInfos; boolean isCombine = shuffleDependency.mapSideCombine(); Function1 createCombiner = null; @@ -322,6 +326,11 @@ private void writeImpl(Iterator> records) { + bufferManager.getManagerCostInfo()); } + // Gluten needs this method + protected void internalCheckBlockSendResult() { + this.checkBlockSendResult(this.blockIds); + } + private void checkSentRecordCount(long recordCount) { if (recordCount != bufferManager.getRecordCount()) { String errorMsg =