From 722d3079aad4a8b0a0b2f3eb9b956658889e0aa2 Mon Sep 17 00:00:00 2001 From: xianjingfeng Date: Tue, 23 Jul 2024 15:28:02 +0800 Subject: [PATCH] [#1887] improvement: reject all requests from unregistered apps in shuffle server (#1923) ### What changes were proposed in this pull request? Reject all requests from unregistered apps in shuffle server ### Why are the changes needed? For better performance. Fix: #1887 ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing UT --- .../uniffle/test/ShuffleServerGrpcTest.java | 4 +- .../server/ShuffleServerGrpcService.java | 164 ++++++++++++++++-- .../uniffle/server/ShuffleTaskManager.java | 8 +- .../netty/ShuffleServerNettyHandler.java | 46 ++++- 4 files changed, 192 insertions(+), 30 deletions(-) diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java index 8a6e0cecf0..df3e29971d 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java @@ -259,7 +259,7 @@ public void shuffleResultTest() throws Exception { grpcShuffleServerClient.reportShuffleResult(request); fail("Exception should be thrown"); } catch (Exception e) { - assertTrue(e.getMessage().contains("error happened when report shuffle result")); + assertTrue(e.getMessage().contains("NO_REGISTER")); } RssGetShuffleResultRequest req = @@ -268,7 +268,7 @@ public void shuffleResultTest() throws Exception { grpcShuffleServerClient.getShuffleResult(req); fail("Exception should be thrown"); } catch (Exception e) { - assertTrue(e.getMessage().contains("Can't get shuffle result")); + assertTrue(e.getMessage().contains("NO_REGISTER")); } RssRegisterShuffleRequest rrsr = diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index aea43d24e7..06e2d8b92a 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -105,19 +105,28 @@ public void unregisterShuffleByAppId( RssProtos.ShuffleUnregisterByAppIdRequest request, StreamObserver responseStreamObserver) { String appId = request.getAppId(); - - StatusCode result = StatusCode.SUCCESS; + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + RssProtos.ShuffleUnregisterByAppIdResponse reply = + RssProtos.ShuffleUnregisterByAppIdResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseStreamObserver.onNext(reply); + responseStreamObserver.onCompleted(); + return; + } String responseMessage = "OK"; try { shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId); } catch (Exception e) { - result = StatusCode.INTERNAL_ERROR; + status = StatusCode.INTERNAL_ERROR; } RssProtos.ShuffleUnregisterByAppIdResponse reply = RssProtos.ShuffleUnregisterByAppIdResponse.newBuilder() - .setStatus(result.toProto()) + .setStatus(status.toProto()) .setRetMsg(responseMessage) .build(); responseStreamObserver.onNext(reply); @@ -129,19 +138,29 @@ public void unregisterShuffle( RssProtos.ShuffleUnregisterRequest request, StreamObserver responseStreamObserver) { String appId = request.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + RssProtos.ShuffleUnregisterResponse reply = + RssProtos.ShuffleUnregisterResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseStreamObserver.onNext(reply); + responseStreamObserver.onCompleted(); + return; + } int shuffleId = request.getShuffleId(); - StatusCode result = StatusCode.SUCCESS; String responseMessage = "OK"; try { shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId, shuffleId); } catch (Exception e) { - result = StatusCode.INTERNAL_ERROR; + status = StatusCode.INTERNAL_ERROR; } RssProtos.ShuffleUnregisterResponse reply = RssProtos.ShuffleUnregisterResponse.newBuilder() - .setStatus(result.toProto()) + .setStatus(status.toProto()) .setRetMsg(responseMessage) .build(); responseStreamObserver.onNext(reply); @@ -430,12 +449,20 @@ public void sendShuffleData( @Override public void commitShuffleTask( ShuffleCommitRequest req, StreamObserver responseObserver) { - - ShuffleCommitResponse reply; String appId = req.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + ShuffleCommitResponse response = + ShuffleCommitResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + return; + } int shuffleId = req.getShuffleId(); - StatusCode status = StatusCode.SUCCESS; String msg = "OK"; int commitCount = 0; @@ -460,7 +487,7 @@ public void commitShuffleTask( LOG.error(msg, e); } - reply = + ShuffleCommitResponse reply = ShuffleCommitResponse.newBuilder() .setCommitCount(commitCount) .setStatus(status.toProto()) @@ -474,8 +501,18 @@ public void commitShuffleTask( public void finishShuffle( FinishShuffleRequest req, StreamObserver responseObserver) { String appId = req.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + FinishShuffleResponse response = + FinishShuffleResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + return; + } int shuffleId = req.getShuffleId(); - StatusCode status; String msg = "OK"; String errorMsg = "Fail to finish shuffle for appId[" @@ -506,8 +543,18 @@ public void finishShuffle( public void requireBuffer( RequireBufferRequest request, StreamObserver responseObserver) { String appId = request.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + RequireBufferResponse response = + RequireBufferResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + return; + } long requireBufferId = -1; - StatusCode status = StatusCode.SUCCESS; try { if (StringUtils.isEmpty(appId)) { // To be compatible with older client version @@ -548,6 +595,17 @@ public void requireBuffer( public void appHeartbeat( AppHeartBeatRequest request, StreamObserver responseObserver) { String appId = request.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + AppHeartBeatResponse response = + AppHeartBeatResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + return; + } shuffleServer.getShuffleTaskManager().refreshAppId(appId); AppHeartBeatResponse response = AppHeartBeatResponse.newBuilder() @@ -572,12 +630,22 @@ public void reportShuffleResult( ReportShuffleResultRequest request, StreamObserver responseObserver) { String appId = request.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + ReportShuffleResultResponse response = + ReportShuffleResultResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + return; + } int shuffleId = request.getShuffleId(); long taskAttemptId = request.getTaskAttemptId(); int bitmapNum = request.getBitmapNum(); Map partitionToBlockIds = toPartitionBlocksMap(request.getPartitionToBlockIdsList()); - StatusCode status = StatusCode.SUCCESS; String msg = "OK"; ReportShuffleResultResponse reply; String requestInfo = @@ -617,6 +685,17 @@ public void reportShuffleResult( public void getShuffleResult( GetShuffleResultRequest request, StreamObserver responseObserver) { String appId = request.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + GetShuffleResultResponse response = + GetShuffleResultResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + return; + } int shuffleId = request.getShuffleId(); int partitionId = request.getPartitionId(); BlockIdLayout blockIdLayout = @@ -624,7 +703,6 @@ public void getShuffleResult( request.getBlockIdLayout().getSequenceNoBits(), request.getBlockIdLayout().getPartitionIdBits(), request.getBlockIdLayout().getTaskAttemptIdBits()); - StatusCode status = StatusCode.SUCCESS; String msg = "OK"; GetShuffleResultResponse reply; byte[] serializedBlockIds = null; @@ -665,6 +743,17 @@ public void getShuffleResultForMultiPart( GetShuffleResultForMultiPartRequest request, StreamObserver responseObserver) { String appId = request.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + GetShuffleResultForMultiPartResponse response = + GetShuffleResultForMultiPartResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + return; + } int shuffleId = request.getShuffleId(); List partitionsList = request.getPartitionsList(); BlockIdLayout blockIdLayout = @@ -673,7 +762,6 @@ public void getShuffleResultForMultiPart( request.getBlockIdLayout().getPartitionIdBits(), request.getBlockIdLayout().getTaskAttemptIdBits()); - StatusCode status = StatusCode.SUCCESS; String msg = "OK"; GetShuffleResultForMultiPartResponse reply; byte[] serializedBlockIds = null; @@ -715,6 +803,17 @@ public void getLocalShuffleData( GetLocalShuffleDataRequest request, StreamObserver responseObserver) { String appId = request.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + GetLocalShuffleDataResponse response = + GetLocalShuffleDataResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + return; + } int shuffleId = request.getShuffleId(); int partitionId = request.getPartitionId(); int partitionNumPerRange = request.getPartitionNumPerRange(); @@ -732,7 +831,6 @@ public void getLocalShuffleData( } String storageType = shuffleServer.getShuffleServerConf().get(RssBaseConf.RSS_STORAGE_TYPE).name(); - StatusCode status = StatusCode.SUCCESS; String msg = "OK"; GetLocalShuffleDataResponse reply = null; ShuffleDataResult sdr = null; @@ -831,11 +929,21 @@ public void getLocalShuffleIndex( GetLocalShuffleIndexRequest request, StreamObserver responseObserver) { String appId = request.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + GetLocalShuffleIndexResponse reply = + GetLocalShuffleIndexResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } int shuffleId = request.getShuffleId(); int partitionId = request.getPartitionId(); int partitionNumPerRange = request.getPartitionNumPerRange(); int partitionNum = request.getPartitionNum(); - StatusCode status = StatusCode.SUCCESS; String msg = "OK"; GetLocalShuffleIndexResponse reply; String requestInfo = @@ -928,6 +1036,17 @@ public void getMemoryShuffleData( GetMemoryShuffleDataRequest request, StreamObserver responseObserver) { String appId = request.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + GetMemoryShuffleDataResponse reply = + GetMemoryShuffleDataResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(status.toString()) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } int shuffleId = request.getShuffleId(); int partitionId = request.getPartitionId(); long blockId = request.getLastBlockId(); @@ -943,7 +1062,6 @@ public void getMemoryShuffleData( ShuffleServerGrpcMetrics.GET_MEMORY_SHUFFLE_DATA_METHOD, transportTime); } } - StatusCode status = StatusCode.SUCCESS; String msg = "OK"; GetMemoryShuffleDataResponse reply; String requestInfo = @@ -1108,4 +1226,12 @@ private List toShuffleDataBlockSegments( } return shuffleDataBlockSegments; } + + private StatusCode verifyRequest(String appId) { + if (StringUtils.isNotBlank(appId) + && shuffleServer.getShuffleTaskManager().isAppExpired(appId)) { + return StatusCode.NO_REGISTER; + } + return StatusCode.SUCCESS; + } } diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index 8fe597d03a..b258c8a119 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -725,12 +725,12 @@ public void checkResourceStatus() { } } - private boolean isAppExpired(String appId) { - if (shuffleTaskInfos.get(appId) == null) { + public boolean isAppExpired(String appId) { + ShuffleTaskInfo shuffleTaskInfo = shuffleTaskInfos.get(appId); + if (shuffleTaskInfo == null) { return true; } - return System.currentTimeMillis() - shuffleTaskInfos.get(appId).getCurrentTimes() - > appExpiredWithoutHB; + return System.currentTimeMillis() - shuffleTaskInfo.getCurrentTimes() > appExpiredWithoutHB; } /** diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java index 27a3f1dc48..cca6a3935b 100644 --- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java +++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java @@ -28,6 +28,7 @@ import io.netty.channel.ChannelFutureListener; import org.apache.commons.collections.MapUtils; import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,7 +70,6 @@ public class ShuffleServerNettyHandler implements BaseMessageHandler { private static final Logger LOG = LoggerFactory.getLogger(ShuffleServerNettyHandler.class); - private static final int RPC_TIMEOUT = 60000; private final ShuffleServer shuffleServer; public ShuffleServerNettyHandler(ShuffleServer shuffleServer) { @@ -335,6 +335,18 @@ private static void releaseNettyBufferAndMetrics( public void handleGetMemoryShuffleDataRequest( TransportClient client, GetMemoryShuffleDataRequest req) { String appId = req.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + GetMemoryShuffleDataResponse response = + new GetMemoryShuffleDataResponse( + req.getRequestId(), + status, + status.toString(), + Lists.newArrayList(), + Unpooled.EMPTY_BUFFER); + client.getChannel().writeAndFlush(response); + return; + } int shuffleId = req.getShuffleId(); int partitionId = req.getPartitionId(); long blockId = req.getLastBlockId(); @@ -349,7 +361,6 @@ public void handleGetMemoryShuffleDataRequest( .recordTransportTime(GetMemoryShuffleDataRequest.class.getName(), transportTime); } } - StatusCode status = StatusCode.SUCCESS; String msg = "OK"; GetMemoryShuffleDataResponse response; String requestInfo = @@ -417,11 +428,18 @@ public void handleGetMemoryShuffleDataRequest( public void handleGetLocalShuffleIndexRequest( TransportClient client, GetLocalShuffleIndexRequest req) { String appId = req.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + GetLocalShuffleIndexResponse response = + new GetLocalShuffleIndexResponse( + req.getRequestId(), status, status.toString(), Unpooled.EMPTY_BUFFER, 0L); + client.getChannel().writeAndFlush(response); + return; + } int shuffleId = req.getShuffleId(); int partitionId = req.getPartitionId(); int partitionNumPerRange = req.getPartitionNumPerRange(); int partitionNum = req.getPartitionNum(); - StatusCode status = StatusCode.SUCCESS; String msg = "OK"; GetLocalShuffleIndexResponse response; String requestInfo = @@ -501,7 +519,19 @@ public void handleGetLocalShuffleIndexRequest( } public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDataRequest req) { + GetLocalShuffleDataResponse response; String appId = req.getAppId(); + StatusCode status = verifyRequest(appId); + if (status != StatusCode.SUCCESS) { + response = + new GetLocalShuffleDataResponse( + req.getRequestId(), + status, + status.toString(), + new NettyManagedBuffer(Unpooled.EMPTY_BUFFER)); + client.getChannel().writeAndFlush(response); + return; + } int shuffleId = req.getShuffleId(); int partitionId = req.getPartitionId(); int partitionNumPerRange = req.getPartitionNumPerRange(); @@ -519,9 +549,7 @@ public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDat } String storageType = shuffleServer.getShuffleServerConf().get(RssBaseConf.RSS_STORAGE_TYPE).name(); - StatusCode status = StatusCode.SUCCESS; String msg = "OK"; - GetLocalShuffleDataResponse response; String requestInfo = "appId[" + appId @@ -625,6 +653,14 @@ private ShufflePartitionedBlock[] toPartitionedBlock(List bloc return ret; } + private StatusCode verifyRequest(String appId) { + if (StringUtils.isNotBlank(appId) + && shuffleServer.getShuffleTaskManager().isAppExpired(appId)) { + return StatusCode.NO_REGISTER; + } + return StatusCode.SUCCESS; + } + class ReleaseMemoryAndRecordReadTimeListener implements ChannelFutureListener { private final long readStartedTime; private final long readBufferSize;