Skip to content

Commit

Permalink
[apache#1887] improvement: reject all requests from unregistered apps…
Browse files Browse the repository at this point in the history
… in shuffle server (apache#1923)

### What changes were proposed in this pull request?
Reject all requests from unregistered apps in shuffle server

### Why are the changes needed?
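For better performance: requests from apps that are not registered (or have expired) are now rejected up front with `NO_REGISTER` instead of being processed.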
Fix: apache#1887

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Existing unit tests (UTs), with the expected error messages updated to `NO_REGISTER`.
  • Loading branch information
xianjingfeng authored Jul 23, 2024
1 parent 78df6c1 commit 722d307
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public void shuffleResultTest() throws Exception {
grpcShuffleServerClient.reportShuffleResult(request);
fail("Exception should be thrown");
} catch (Exception e) {
assertTrue(e.getMessage().contains("error happened when report shuffle result"));
assertTrue(e.getMessage().contains("NO_REGISTER"));
}

RssGetShuffleResultRequest req =
Expand All @@ -268,7 +268,7 @@ public void shuffleResultTest() throws Exception {
grpcShuffleServerClient.getShuffleResult(req);
fail("Exception should be thrown");
} catch (Exception e) {
assertTrue(e.getMessage().contains("Can't get shuffle result"));
assertTrue(e.getMessage().contains("NO_REGISTER"));
}

RssRegisterShuffleRequest rrsr =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,28 @@ public void unregisterShuffleByAppId(
RssProtos.ShuffleUnregisterByAppIdRequest request,
StreamObserver<RssProtos.ShuffleUnregisterByAppIdResponse> responseStreamObserver) {
String appId = request.getAppId();

StatusCode result = StatusCode.SUCCESS;
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
RssProtos.ShuffleUnregisterByAppIdResponse reply =
RssProtos.ShuffleUnregisterByAppIdResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseStreamObserver.onNext(reply);
responseStreamObserver.onCompleted();
return;
}
String responseMessage = "OK";
try {
shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId);

} catch (Exception e) {
result = StatusCode.INTERNAL_ERROR;
status = StatusCode.INTERNAL_ERROR;
}

RssProtos.ShuffleUnregisterByAppIdResponse reply =
RssProtos.ShuffleUnregisterByAppIdResponse.newBuilder()
.setStatus(result.toProto())
.setStatus(status.toProto())
.setRetMsg(responseMessage)
.build();
responseStreamObserver.onNext(reply);
Expand All @@ -129,19 +138,29 @@ public void unregisterShuffle(
RssProtos.ShuffleUnregisterRequest request,
StreamObserver<RssProtos.ShuffleUnregisterResponse> responseStreamObserver) {
String appId = request.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
RssProtos.ShuffleUnregisterResponse reply =
RssProtos.ShuffleUnregisterResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseStreamObserver.onNext(reply);
responseStreamObserver.onCompleted();
return;
}
int shuffleId = request.getShuffleId();

StatusCode result = StatusCode.SUCCESS;
String responseMessage = "OK";
try {
shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId, shuffleId);
} catch (Exception e) {
result = StatusCode.INTERNAL_ERROR;
status = StatusCode.INTERNAL_ERROR;
}

RssProtos.ShuffleUnregisterResponse reply =
RssProtos.ShuffleUnregisterResponse.newBuilder()
.setStatus(result.toProto())
.setStatus(status.toProto())
.setRetMsg(responseMessage)
.build();
responseStreamObserver.onNext(reply);
Expand Down Expand Up @@ -430,12 +449,20 @@ public void sendShuffleData(
@Override
public void commitShuffleTask(
ShuffleCommitRequest req, StreamObserver<ShuffleCommitResponse> responseObserver) {

ShuffleCommitResponse reply;
String appId = req.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
ShuffleCommitResponse response =
ShuffleCommitResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
return;
}
int shuffleId = req.getShuffleId();

StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
int commitCount = 0;

Expand All @@ -460,7 +487,7 @@ public void commitShuffleTask(
LOG.error(msg, e);
}

reply =
ShuffleCommitResponse reply =
ShuffleCommitResponse.newBuilder()
.setCommitCount(commitCount)
.setStatus(status.toProto())
Expand All @@ -474,8 +501,18 @@ public void commitShuffleTask(
public void finishShuffle(
FinishShuffleRequest req, StreamObserver<FinishShuffleResponse> responseObserver) {
String appId = req.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
FinishShuffleResponse response =
FinishShuffleResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
return;
}
int shuffleId = req.getShuffleId();
StatusCode status;
String msg = "OK";
String errorMsg =
"Fail to finish shuffle for appId["
Expand Down Expand Up @@ -506,8 +543,18 @@ public void finishShuffle(
public void requireBuffer(
RequireBufferRequest request, StreamObserver<RequireBufferResponse> responseObserver) {
String appId = request.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
RequireBufferResponse response =
RequireBufferResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
return;
}
long requireBufferId = -1;
StatusCode status = StatusCode.SUCCESS;
try {
if (StringUtils.isEmpty(appId)) {
// To be compatible with older client version
Expand Down Expand Up @@ -548,6 +595,17 @@ public void requireBuffer(
public void appHeartbeat(
AppHeartBeatRequest request, StreamObserver<AppHeartBeatResponse> responseObserver) {
String appId = request.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
AppHeartBeatResponse response =
AppHeartBeatResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
return;
}
shuffleServer.getShuffleTaskManager().refreshAppId(appId);
AppHeartBeatResponse response =
AppHeartBeatResponse.newBuilder()
Expand All @@ -572,12 +630,22 @@ public void reportShuffleResult(
ReportShuffleResultRequest request,
StreamObserver<ReportShuffleResultResponse> responseObserver) {
String appId = request.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
ReportShuffleResultResponse response =
ReportShuffleResultResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
return;
}
int shuffleId = request.getShuffleId();
long taskAttemptId = request.getTaskAttemptId();
int bitmapNum = request.getBitmapNum();
Map<Integer, long[]> partitionToBlockIds =
toPartitionBlocksMap(request.getPartitionToBlockIdsList());
StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
ReportShuffleResultResponse reply;
String requestInfo =
Expand Down Expand Up @@ -617,14 +685,24 @@ public void reportShuffleResult(
public void getShuffleResult(
GetShuffleResultRequest request, StreamObserver<GetShuffleResultResponse> responseObserver) {
String appId = request.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
GetShuffleResultResponse response =
GetShuffleResultResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
return;
}
int shuffleId = request.getShuffleId();
int partitionId = request.getPartitionId();
BlockIdLayout blockIdLayout =
BlockIdLayout.from(
request.getBlockIdLayout().getSequenceNoBits(),
request.getBlockIdLayout().getPartitionIdBits(),
request.getBlockIdLayout().getTaskAttemptIdBits());
StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetShuffleResultResponse reply;
byte[] serializedBlockIds = null;
Expand Down Expand Up @@ -665,6 +743,17 @@ public void getShuffleResultForMultiPart(
GetShuffleResultForMultiPartRequest request,
StreamObserver<GetShuffleResultForMultiPartResponse> responseObserver) {
String appId = request.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
GetShuffleResultForMultiPartResponse response =
GetShuffleResultForMultiPartResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
return;
}
int shuffleId = request.getShuffleId();
List<Integer> partitionsList = request.getPartitionsList();
BlockIdLayout blockIdLayout =
Expand All @@ -673,7 +762,6 @@ public void getShuffleResultForMultiPart(
request.getBlockIdLayout().getPartitionIdBits(),
request.getBlockIdLayout().getTaskAttemptIdBits());

StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetShuffleResultForMultiPartResponse reply;
byte[] serializedBlockIds = null;
Expand Down Expand Up @@ -715,6 +803,17 @@ public void getLocalShuffleData(
GetLocalShuffleDataRequest request,
StreamObserver<GetLocalShuffleDataResponse> responseObserver) {
String appId = request.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
GetLocalShuffleDataResponse response =
GetLocalShuffleDataResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();
return;
}
int shuffleId = request.getShuffleId();
int partitionId = request.getPartitionId();
int partitionNumPerRange = request.getPartitionNumPerRange();
Expand All @@ -732,7 +831,6 @@ public void getLocalShuffleData(
}
String storageType =
shuffleServer.getShuffleServerConf().get(RssBaseConf.RSS_STORAGE_TYPE).name();
StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetLocalShuffleDataResponse reply = null;
ShuffleDataResult sdr = null;
Expand Down Expand Up @@ -831,11 +929,21 @@ public void getLocalShuffleIndex(
GetLocalShuffleIndexRequest request,
StreamObserver<GetLocalShuffleIndexResponse> responseObserver) {
String appId = request.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
GetLocalShuffleIndexResponse reply =
GetLocalShuffleIndexResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseObserver.onNext(reply);
responseObserver.onCompleted();
return;
}
int shuffleId = request.getShuffleId();
int partitionId = request.getPartitionId();
int partitionNumPerRange = request.getPartitionNumPerRange();
int partitionNum = request.getPartitionNum();
StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetLocalShuffleIndexResponse reply;
String requestInfo =
Expand Down Expand Up @@ -928,6 +1036,17 @@ public void getMemoryShuffleData(
GetMemoryShuffleDataRequest request,
StreamObserver<GetMemoryShuffleDataResponse> responseObserver) {
String appId = request.getAppId();
StatusCode status = verifyRequest(appId);
if (status != StatusCode.SUCCESS) {
GetMemoryShuffleDataResponse reply =
GetMemoryShuffleDataResponse.newBuilder()
.setStatus(status.toProto())
.setRetMsg(status.toString())
.build();
responseObserver.onNext(reply);
responseObserver.onCompleted();
return;
}
int shuffleId = request.getShuffleId();
int partitionId = request.getPartitionId();
long blockId = request.getLastBlockId();
Expand All @@ -943,7 +1062,6 @@ public void getMemoryShuffleData(
ShuffleServerGrpcMetrics.GET_MEMORY_SHUFFLE_DATA_METHOD, transportTime);
}
}
StatusCode status = StatusCode.SUCCESS;
String msg = "OK";
GetMemoryShuffleDataResponse reply;
String requestInfo =
Expand Down Expand Up @@ -1108,4 +1226,12 @@ private List<ShuffleDataBlockSegment> toShuffleDataBlockSegments(
}
return shuffleDataBlockSegments;
}

/**
 * Decides whether a request carrying the given application id may be served.
 *
 * <p>A blank {@code appId} skips the check entirely and is accepted; the shuffle task manager is
 * only consulted for non-blank ids.
 *
 * @param appId application id read from the incoming request
 * @return {@code StatusCode.NO_REGISTER} when {@code appId} is non-blank and the shuffle task
 *     manager reports the app as expired; {@code StatusCode.SUCCESS} otherwise
 */
private StatusCode verifyRequest(String appId) {
  // Blank ids are let through without asking the task manager.
  if (StringUtils.isBlank(appId)) {
    return StatusCode.SUCCESS;
  }
  boolean appExpired = shuffleServer.getShuffleTaskManager().isAppExpired(appId);
  return appExpired ? StatusCode.NO_REGISTER : StatusCode.SUCCESS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -725,12 +725,12 @@ public void checkResourceStatus() {
}
}

private boolean isAppExpired(String appId) {
if (shuffleTaskInfos.get(appId) == null) {
public boolean isAppExpired(String appId) {
ShuffleTaskInfo shuffleTaskInfo = shuffleTaskInfos.get(appId);
if (shuffleTaskInfo == null) {
return true;
}
return System.currentTimeMillis() - shuffleTaskInfos.get(appId).getCurrentTimes()
> appExpiredWithoutHB;
return System.currentTimeMillis() - shuffleTaskInfo.getCurrentTimes() > appExpiredWithoutHB;
}

/**
Expand Down
Loading

0 comments on commit 722d307

Please sign in to comment.