Skip to content

Commit

Permalink
[#1751][0.9] improvement: support gluten (#1753)
Browse files Browse the repository at this point in the history
* support gluten

* optimize

* fix bug

* nit

* fix spotless

* nit

* nit

* fix bug

* optimize

* optimize

* nit

* nit

* nit

* nit

* nit

* Update RssShuffleWriter.java
  • Loading branch information
xianjingfeng authored Jun 18, 2024
1 parent a6a715f commit 4944d54
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -475,15 +475,6 @@ public <K, V> ShuffleWriter<K, V> getWriter(

int shuffleId = rssHandle.getShuffleId();
String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
ShuffleHandleInfo shuffleHandleInfo;
if (shuffleManagerRpcServiceEnabled) {
// Get the ShuffleServer list from the Driver based on the shuffleId
shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
} else {
shuffleHandleInfo =
new ShuffleHandleInfo(
shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
}
ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
return new RssShuffleWriter<>(
rssHandle.getAppId(),
Expand All @@ -496,8 +487,7 @@ public <K, V> ShuffleWriter<K, V> getWriter(
shuffleWriteClient,
rssHandle,
this::markFailedTask,
context,
shuffleHandleInfo);
context);
} else {
throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
}
Expand Down Expand Up @@ -806,6 +796,18 @@ private ShuffleManagerClient createShuffleManagerClient(String host, int port) {
.createShuffleManagerClient(ClientType.GRPC, host, port);
}

public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
if (shuffleManagerRpcServiceEnabled) {
// Get the ShuffleServer list from the Driver based on the shuffleId
return getRemoteShuffleHandleInfo(rssHandle.getShuffleId());
} else {
return new ShuffleHandleInfo(
rssHandle.getShuffleId(),
rssHandle.getPartitionToServers(),
rssHandle.getRemoteStorage());
}
}

/**
* Get the ShuffleServer list from the Driver based on the shuffleId
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ public RssShuffleWriter(
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
TaskContext context,
ShuffleHandleInfo shuffleHandleInfo) {
TaskContext context) {
this(
appId,
shuffleId,
Expand All @@ -201,9 +200,10 @@ public RssShuffleWriter(
shuffleWriteClient,
rssHandle,
taskFailureCallback,
shuffleHandleInfo,
shuffleManager.getShuffleHandleInfo(rssHandle),
context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
ShuffleHandleInfo shuffleHandleInfo = shuffleManager.getShuffleHandleInfo(rssHandle);
final WriteBufferManager bufferManager =
new WriteBufferManager(
shuffleId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
private boolean rssResubmitStage;

private boolean taskBlockSendFailureRetryEnabled;

private boolean shuffleManagerRpcServiceEnabled;
/** A list of shuffleServer for Write failures */
private Set<String> failuresShuffleServerIds;
Expand Down Expand Up @@ -514,15 +513,6 @@ public <K, V> ShuffleWriter<K, V> getWriter(
} else {
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
}
ShuffleHandleInfo shuffleHandleInfo;
if (shuffleManagerRpcServiceEnabled) {
// Get the ShuffleServer list from the Driver based on the shuffleId
shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
} else {
shuffleHandleInfo =
new ShuffleHandleInfo(
shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
}
String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId());
return new RssShuffleWriter<>(
Expand All @@ -536,8 +526,7 @@ public <K, V> ShuffleWriter<K, V> getWriter(
shuffleWriteClient,
rssHandle,
this::markFailedTask,
context,
shuffleHandleInfo);
context);
}

@Override
Expand Down Expand Up @@ -656,17 +645,7 @@ public <K, C> ShuffleReader<K, C> getReaderImpl(
RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>) handle;
final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
int shuffleId = rssShuffleHandle.getShuffleId();
ShuffleHandleInfo shuffleHandleInfo;
if (shuffleManagerRpcServiceEnabled) {
// Get the ShuffleServer list from the Driver based on the shuffleId
shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
} else {
shuffleHandleInfo =
new ShuffleHandleInfo(
shuffleId,
rssShuffleHandle.getPartitionToServers(),
rssShuffleHandle.getRemoteStorage());
}
ShuffleHandleInfo shuffleHandleInfo = getShuffleHandleInfo(rssShuffleHandle);
Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
shuffleHandleInfo.getPartitionToServers();
Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
Expand Down Expand Up @@ -1101,6 +1080,18 @@ private ShuffleManagerClient createShuffleManagerClient(String host, int port) {
.createShuffleManagerClient(ClientType.GRPC, host, port);
}

public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
if (shuffleManagerRpcServiceEnabled) {
// Get the ShuffleServer list from the Driver based on the shuffleId
return getRemoteShuffleHandleInfo(rssHandle.getShuffleId());
} else {
return new ShuffleHandleInfo(
rssHandle.getShuffleId(),
rssHandle.getPartitionToServers(),
rssHandle.getRemoteStorage());
}
}

/**
* Get the ShuffleServer list from the Driver based on the shuffleId
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {

private final String appId;
private final int shuffleId;
private final ShuffleHandleInfo shuffleHandleInfo;
private WriteBufferManager bufferManager;
private String taskId;
private final int numMaps;
Expand All @@ -110,7 +111,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final ShuffleWriteClient shuffleWriteClient;
private final Set<ShuffleServerInfo> shuffleServersForData;
private final long[] partitionLengths;
private final boolean isMemoryShuffleEnabled;
// Gluten needs this variable
protected final boolean isMemoryShuffleEnabled;
private final Function<String, Boolean> taskFailureCallback;
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
private TaskContext taskContext;
Expand Down Expand Up @@ -195,6 +197,7 @@ private RssShuffleWriter(
this.isMemoryShuffleEnabled =
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
this.taskFailureCallback = taskFailureCallback;
this.shuffleHandleInfo = shuffleHandleInfo;
this.taskContext = context;
this.sparkConf = sparkConf;
this.blockFailSentRetryEnabled =
Expand All @@ -204,6 +207,7 @@ private RssShuffleWriter(
RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.defaultValue());
}

// Gluten needs this constructor
public RssShuffleWriter(
String appId,
int shuffleId,
Expand All @@ -215,8 +219,7 @@ public RssShuffleWriter(
ShuffleWriteClient shuffleWriteClient,
RssShuffleHandle<K, V, C> rssHandle,
Function<String, Boolean> taskFailureCallback,
TaskContext context,
ShuffleHandleInfo shuffleHandleInfo) {
TaskContext context) {
this(
appId,
shuffleId,
Expand All @@ -228,7 +231,7 @@ public RssShuffleWriter(
shuffleWriteClient,
rssHandle,
taskFailureCallback,
shuffleHandleInfo,
shuffleManager.getShuffleHandleInfo(rssHandle),
context);
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
final WriteBufferManager bufferManager =
Expand Down Expand Up @@ -264,7 +267,8 @@ public void write(Iterator<Product2<K, V>> records) {
}
}

private void writeImpl(Iterator<Product2<K, V>> records) {
// Gluten needs this method.
protected void writeImpl(Iterator<Product2<K, V>> records) {
List<ShuffleBlockInfo> shuffleBlockInfos;
boolean isCombine = shuffleDependency.mapSideCombine();
Function1<V, C> createCombiner = null;
Expand Down Expand Up @@ -322,6 +326,11 @@ private void writeImpl(Iterator<Product2<K, V>> records) {
+ bufferManager.getManagerCostInfo());
}

// Gluten needs this method
protected void internalCheckBlockSendResult() {
this.checkBlockSendResult(this.blockIds);
}

private void checkSentRecordCount(long recordCount) {
if (recordCount != bufferManager.getRecordCount()) {
String errorMsg =
Expand Down

0 comments on commit 4944d54

Please sign in to comment.