Skip to content

Commit

Permalink
[apache#2094] feat(client): Introduce retry mechanism for coordinator…
Browse files Browse the repository at this point in the history
… client (apache#2095)

### Why are the changes needed?

Fix: apache#2094 

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

UTs
  • Loading branch information
kuszz authored Sep 9, 2024
1 parent 529f0c7 commit a6aefcb
Show file tree
Hide file tree
Showing 16 changed files with 486 additions and 263 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ public boolean sendCommit(
}

@Override
public void registerCoordinators(String coordinators) {}
public void registerCoordinators(String coordinators, long retryIntervalMs, int retryTimes) {}

@Override
public Map<String, String> fetchClientConf(int timeoutMs) {
Expand Down Expand Up @@ -578,7 +578,9 @@ public ShuffleAssignmentsInfo getShuffleAssignments(
Set<String> faultyServerIds,
int stageId,
int stageAttemptNumber,
boolean reassign) {
boolean reassign,
long retryIntervalMs,
int retryTimes) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ public boolean sendCommit(
}

@Override
public void registerCoordinators(String coordinators) {}
public void registerCoordinators(String coordinators, long retryIntervalMs, int retryTimes) {}

@Override
public Map<String, String> fetchClientConf(int timeoutMs) {
Expand Down Expand Up @@ -547,7 +547,9 @@ public ShuffleAssignmentsInfo getShuffleAssignments(
Set<String> faultyServerIds,
int stageId,
int stageAttemptNumber,
boolean reassign) {
boolean reassign,
long retryIntervalMs,
int retryTimes) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.util.ClientUtils;
Expand Down Expand Up @@ -103,12 +103,15 @@ public static ShuffleManager loadShuffleManager(String name, SparkConf conf, boo
return instance;
}

public static List<CoordinatorClient> createCoordinatorClients(SparkConf sparkConf) {
public static CoordinatorGrpcRetryableClient createCoordinatorClients(SparkConf sparkConf) {
String clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM);
long retryIntervalMs = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
int heartbeatThread = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
CoordinatorClientFactory coordinatorClientFactory = CoordinatorClientFactory.getInstance();
return coordinatorClientFactory.createCoordinatorClient(
ClientType.valueOf(clientType), coordinators);
ClientType.valueOf(clientType), coordinators, retryIntervalMs, retryTimes, heartbeatThread);
}

public static void applyDynamicClientConf(SparkConf sparkConf, Map<String, String> confItems) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
import org.apache.uniffle.client.response.RssFetchClientConfResponse;
Expand Down Expand Up @@ -399,10 +399,17 @@ protected static long getTaskAttemptIdForBlockId(
protected static void fetchAndApplyDynamicConf(SparkConf sparkConf) {
String clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
long retryIntervalMs = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
int heartbeatThread = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
CoordinatorClientFactory coordinatorClientFactory = CoordinatorClientFactory.getInstance();
List<CoordinatorClient> coordinatorClients =
CoordinatorGrpcRetryableClient coordinatorClients =
coordinatorClientFactory.createCoordinatorClient(
ClientType.valueOf(clientType), coordinators);
ClientType.valueOf(clientType),
coordinators,
retryIntervalMs,
retryTimes,
heartbeatThread);

int timeoutMs =
sparkConf.getInt(
Expand All @@ -416,18 +423,11 @@ protected static void fetchAndApplyDynamicConf(SparkConf sparkConf) {
}
RssFetchClientConfRequest request =
new RssFetchClientConfRequest(timeoutMs, user, Collections.emptyMap());
for (CoordinatorClient client : coordinatorClients) {
RssFetchClientConfResponse response = client.fetchClientConf(request);
if (response.getStatusCode() == StatusCode.SUCCESS) {
LOG.info("Success to get conf from {}", client.getDesc());
RssSparkShuffleUtils.applyDynamicClientConf(sparkConf, response.getClientConf());
break;
} else {
LOG.warn("Fail to get conf from {}", client.getDesc());
}
RssFetchClientConfResponse response = coordinatorClients.fetchClientConf(request);
if (response.getStatusCode() == StatusCode.SUCCESS) {
RssSparkShuffleUtils.applyDynamicClientConf(sparkConf, response.getClientConf());
}

coordinatorClients.forEach(CoordinatorClient::close);
coordinatorClients.close();
}

@Override
Expand Down Expand Up @@ -902,31 +902,28 @@ private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
try {
return RetryUtils.retry(
() -> {
ShuffleAssignmentsInfo response =
shuffleWriteClient.getShuffleAssignments(
id.get(),
shuffleId,
partitionNum,
partitionNumPerRange,
assignmentTags,
assignmentShuffleServerNumber,
estimateTaskConcurrency,
faultyServerIds,
stageId,
stageAttemptNumber,
reassign);
LOG.info("Finished reassign");
if (reassignmentHandler != null) {
response = reassignmentHandler.apply(response);
}
registerShuffleServers(
id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo());
return response.getPartitionToServers();
},
retryInterval,
retryTimes);
ShuffleAssignmentsInfo response =
shuffleWriteClient.getShuffleAssignments(
id.get(),
shuffleId,
partitionNum,
partitionNumPerRange,
assignmentTags,
assignmentShuffleServerNumber,
estimateTaskConcurrency,
faultyServerIds,
stageId,
stageAttemptNumber,
reassign,
retryInterval,
retryTimes);
LOG.info("Finished reassign");
if (reassignmentHandler != null) {
response = reassignmentHandler.apply(response);
}
registerShuffleServers(
id.get(), shuffleId, response.getServerToPartitionRanges(), getRemoteStorageInfo());
return response.getPartitionToServers();
} catch (Throwable throwable) {
throw new RssException("registerShuffle failed!", throwable);
}
Expand All @@ -950,21 +947,23 @@ protected Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(
int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
try {
ShuffleAssignmentsInfo response =
shuffleWriteClient.getShuffleAssignments(
appId,
shuffleId,
partitionNum,
partitionNumPerRange,
assignmentTags,
assignmentShuffleServerNumber,
estimateTaskConcurrency,
faultyServerIds,
stageId,
stageAttemptNumber,
reassign,
retryInterval,
retryTimes);
return RetryUtils.retry(
() -> {
ShuffleAssignmentsInfo response =
shuffleWriteClient.getShuffleAssignments(
appId,
shuffleId,
partitionNum,
partitionNumPerRange,
assignmentTags,
assignmentShuffleServerNumber,
estimateTaskConcurrency,
faultyServerIds,
stageId,
stageAttemptNumber,
reassign);
registerShuffleServers(
appId,
shuffleId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
import org.apache.uniffle.client.response.RssFetchClientConfResponse;
import org.apache.uniffle.common.ClientType;
Expand Down Expand Up @@ -647,6 +648,12 @@ void testFetchAndApplyDynamicConf() {
try (MockedStatic<CoordinatorClientFactory> mockFactoryStatic =
mockStatic(CoordinatorClientFactory.class)) {
mockFactoryStatic.when(CoordinatorClientFactory::getInstance).thenReturn(mockFactoryInstance);
long interval = conf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
int retry = conf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
int num = conf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
when(mockFactoryInstance.createCoordinatorClient(
clientType, coordinators, interval, retry, num))
.thenReturn(new CoordinatorGrpcRetryableClient(mockClients, interval, retry, num));
RssShuffleManagerBase.fetchAndApplyDynamicConf(conf);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

package org.apache.spark.shuffle;

import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.security.UserGroupInformation;
Expand All @@ -31,13 +29,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient;
import org.apache.uniffle.client.request.RssAccessClusterRequest;
import org.apache.uniffle.client.response.RssAccessClusterResponse;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.RetryUtils;

import static org.apache.uniffle.common.util.Constants.ACCESS_INFO_REQUIRED_SHUFFLE_NODES_NUM;

Expand All @@ -46,7 +43,7 @@ public class DelegationRssShuffleManager implements ShuffleManager {
private static final Logger LOG = LoggerFactory.getLogger(DelegationRssShuffleManager.class);

private final ShuffleManager delegate;
private final List<CoordinatorClient> coordinatorClients;
private final CoordinatorGrpcRetryableClient coordinatorClients;
private final int accessTimeoutMs;
private final SparkConf sparkConf;
private String user;
Expand All @@ -61,7 +58,7 @@ public DelegationRssShuffleManager(SparkConf sparkConf, boolean isDriver) throws
coordinatorClients = RssSparkShuffleUtils.createCoordinatorClients(sparkConf);
delegate = createShuffleManagerInDriver();
} else {
coordinatorClients = Lists.newArrayList();
coordinatorClients = null;
delegate = createShuffleManagerInExecutor();
}

Expand Down Expand Up @@ -127,50 +124,31 @@ private boolean tryAccessCluster() {
extraProperties.put(
ACCESS_INFO_REQUIRED_SHUFFLE_NODES_NUM, String.valueOf(assignmentShuffleNodesNum));

for (CoordinatorClient coordinatorClient : coordinatorClients) {
Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
boolean canAccess;
try {
canAccess =
RetryUtils.retry(
() -> {
RssAccessClusterResponse response =
coordinatorClient.accessCluster(
new RssAccessClusterRequest(
accessId, assignmentTags, accessTimeoutMs, extraProperties, user));
if (response.getStatusCode() == StatusCode.SUCCESS) {
LOG.warn(
"Success to access cluster {} using {}",
coordinatorClient.getDesc(),
accessId);
uuid = response.getUuid();
return true;
} else if (response.getStatusCode() == StatusCode.ACCESS_DENIED) {
throw new RssException(
"Request to access cluster "
+ coordinatorClient.getDesc()
+ " is denied using "
+ accessId
+ " for "
+ response.getMessage());
} else {
throw new RssException(
"Fail to reach cluster "
+ coordinatorClient.getDesc()
+ " for "
+ response.getMessage());
}
},
Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
try {
if (coordinatorClients != null) {
RssAccessClusterResponse response =
coordinatorClients.accessCluster(
new RssAccessClusterRequest(
accessId, assignmentTags, accessTimeoutMs, extraProperties, user),
retryInterval,
retryTimes);
return canAccess;
} catch (Throwable e) {
LOG.warn(
"Fail to access cluster {} using {} for {}",
coordinatorClient.getDesc(),
accessId,
e.getMessage());
if (response.getStatusCode() == StatusCode.SUCCESS) {
LOG.warn("Success to access cluster using {}", accessId);
uuid = response.getUuid();
return true;
} else if (response.getStatusCode() == StatusCode.ACCESS_DENIED) {
throw new RssException(
"Request to access cluster is denied using "
+ accessId
+ " for "
+ response.getMessage());
} else {
throw new RssException("Fail to reach cluster for " + response.getMessage());
}
}
} catch (Throwable e) {
LOG.warn("Fail to access cluster using {} for ", accessId, e);
}

return false;
Expand Down Expand Up @@ -227,7 +205,9 @@ public boolean unregisterShuffle(int shuffleId) {
@Override
public void stop() {
delegate.stop();
coordinatorClients.forEach(CoordinatorClient::close);
if (coordinatorClients != null) {
coordinatorClients.close();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,10 @@ private void startHeartbeat() {
protected void registerCoordinator() {
String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
LOG.info("Registering coordinators {}", coordinators);
shuffleWriteClient.registerCoordinators(coordinators);
shuffleWriteClient.registerCoordinators(
coordinators,
this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX),
this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX));
}

public CompletableFuture<Long> sendData(AddBlockEvent event) {
Expand Down
Loading

0 comments on commit a6aefcb

Please sign in to comment.