From d170004a32418a32536aa792b198dfd714b5ea8c Mon Sep 17 00:00:00 2001 From: zhengchenyu Date: Wed, 11 Sep 2024 10:46:18 +0800 Subject: [PATCH] [#1749] feat(remote merge): Introduce new reader for reading sorted data. (#2034) ### What changes were proposed in this pull request? Introduce a new reader for reading sorted data. Since #1748 already provides methods for merging blocks, we need to provide a method for reading the merged blocks. The records obtained from the getSortedShuffleData method are sorted using the comparatorClassName that is passed via registerShuffle. ### Why are the changes needed? Fix: #1749 ### Does this PR introduce _any_ user-facing change? Yes; documentation will be added in a separate issue. ### How was this patch tested? Unit tests, integration tests, and a real job in a cluster. --- .../mapred/SortWriteBufferManagerTest.java | 15 +- .../mapreduce/task/reduce/FetcherTest.java | 15 +- .../manager/RssShuffleManagerBase.java | 7 +- .../sort/buffer/WriteBufferManagerTest.java | 15 +- client/pom.xml | 6 + .../client/api/ShuffleWriteClient.java | 21 +- .../client/impl/ShuffleWriteClientImpl.java | 66 +- .../apache/uniffle/client/record/Record.java | 69 ++ .../uniffle/client/record/RecordBlob.java | 110 +++ .../uniffle/client/record/RecordBuffer.java | 112 +++ .../client/record/RecordCollection.java | 31 + .../record/metrics/MetricsReporter.java | 22 + .../client/record/reader/BufferedSegment.java | 61 ++ .../client/record/reader/KeyValueReader.java | 29 + .../client/record/reader/KeyValuesReader.java | 29 + .../client/record/reader/RMRecordsReader.java | 616 ++++++++++++ .../client/record/writer/Combiner.java | 38 + .../record/reader/BufferedSegmentTest.java | 99 ++ .../reader/MockedShuffleServerClient.java | 203 ++++ .../reader/MockedShuffleWriteClient.java | 182 ++++ .../record/reader/RMRecordsReaderTest.java | 335 +++++++ .../record/writer/RecordCollectionTest.java | 158 +++ .../record/writer/SumByKeyCombiner.java | 123 +++ .../uniffle/common/config/RssClientConf.java | 25 + .../common/merger/StreamedSegment.java | 20 +- .../BufferPartialInputStreamImpl.java | 62 ++ .../common/serializer/PartialInputStream.java | 34 + .../serializer/PartialInputStreamImpl.java | 32 - .../SeekableInMemoryByteChannel.java | 165 ---- .../uniffle/common/merger/MergerTest.java | 8 +- .../records/RecordsReaderWriterTest.java | 99 +- .../serializer/PartialInputStreamTest.java | 60 +- .../common/serializer/SerializerUtils.java | 4 +- .../serializer/WritableSerializerTest.java | 45 +- .../RemoteMergeShuffleWithRssClientTest.java | 917 +++++++++++++++++ ...leWithRssClientTestWhenShuffleFlushed.java | 924 ++++++++++++++++++ .../client/api/ShuffleServerClient.java | 8 + .../impl/grpc/ShuffleServerGrpcClient.java | 146 ++- .../RssGetSortedShuffleDataRequest.java | 50 + .../request/RssRegisterShuffleRequest.java | 58 +- .../request/RssStartSortMergeRequest.java | 52 + .../RssGetSortedShuffleDataResponse.java | 49 + .../response/RssStartSortMergeResponse.java | 27 + proto/src/main/proto/Rss.proto | 33 + .../server/ShuffleServerGrpcService.java | 271 ++++- .../merge/BlockFlushFileReaderTest.java | 7 +- .../server/merge/MergedResultTest.java | 5 +- .../server/merge/ShuffleMergeManagerTest.java | 7 +- 48 files changed, 5133 insertions(+), 337 deletions(-) create mode 100644 client/src/main/java/org/apache/uniffle/client/record/Record.java create mode 100644 client/src/main/java/org/apache/uniffle/client/record/RecordBlob.java create mode 100644 client/src/main/java/org/apache/uniffle/client/record/RecordBuffer.java create mode 
100644 client/src/main/java/org/apache/uniffle/client/record/RecordCollection.java create mode 100644 client/src/main/java/org/apache/uniffle/client/record/metrics/MetricsReporter.java create mode 100644 client/src/main/java/org/apache/uniffle/client/record/reader/BufferedSegment.java create mode 100644 client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java create mode 100644 client/src/main/java/org/apache/uniffle/client/record/reader/KeyValuesReader.java create mode 100644 client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java create mode 100644 client/src/main/java/org/apache/uniffle/client/record/writer/Combiner.java create mode 100644 client/src/test/java/org/apache/uniffle/client/record/reader/BufferedSegmentTest.java create mode 100644 client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleServerClient.java create mode 100644 client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java create mode 100644 client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java create mode 100644 client/src/test/java/org/apache/uniffle/client/record/writer/RecordCollectionTest.java create mode 100644 client/src/test/java/org/apache/uniffle/client/record/writer/SumByKeyCombiner.java create mode 100644 common/src/main/java/org/apache/uniffle/common/serializer/BufferPartialInputStreamImpl.java delete mode 100644 common/src/main/java/org/apache/uniffle/common/serializer/SeekableInMemoryByteChannel.java create mode 100644 integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java create mode 100644 integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java create mode 100644 internal-client/src/main/java/org/apache/uniffle/client/request/RssGetSortedShuffleDataRequest.java create mode 100644 internal-client/src/main/java/org/apache/uniffle/client/request/RssStartSortMergeRequest.java create mode 100644 internal-client/src/main/java/org/apache/uniffle/client/response/RssGetSortedShuffleDataResponse.java create mode 100644 internal-client/src/main/java/org/apache/uniffle/client/response/RssStartSortMergeResponse.java diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java index c0da9ec977..284e16f030 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java @@ -518,7 +518,12 @@ public void registerShuffle( RemoteStorageInfo remoteStorage, ShuffleDataDistributionType distributionType, int maxConcurrencyPerPartitionToWrite, - int stageAttemptNumber) {} + int stageAttemptNumber, + String keyClassName, + String valueClassName, + String comparatorClassName, + int mergedBlockSize, + String mergeClassLoader) {} @Override public boolean sendCommit( @@ -613,6 +618,14 @@ public void unregisterShuffle(String appId, int shuffleId) {} @Override public void unregisterShuffle(String appId) {} + + @Override + public void startSortMerge( + Set serverInfos, + String appId, + int shuffleId, + int partitionId, + Roaring64NavigableMap expectedTaskIds) {} } static class Reduce extends MapReduceBase diff --git a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java index 7a58ff45d6..3569bca8d3 100644 --- a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java +++ b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java @@ -506,7 +506,12 @@ public void registerShuffle( RemoteStorageInfo storageType, ShuffleDataDistributionType distributionType, int maxConcurrencyPerPartitionToWrite, - int stageAttemptNumber) {} + int stageAttemptNumber, + String keyClassName, + String valueClassName, + String comparatorClassName, + int mergedBlockSize, + String mergeClassLoader) {} @Override public boolean sendCommit( @@ -582,6 +587,14 @@ public void unregisterShuffle(String appId, int shuffleId) {} @Override public void unregisterShuffle(String appId) {} + + @Override + public void startSortMerge( + Set serverInfos, + String appId, + int shuffleId, + int partitionId, + Roaring64NavigableMap expectedTaskIds) {} } static class MockedShuffleReadClient implements ShuffleReadClient { diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 077ffe9d92..767bb03ea5 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -1021,7 +1021,12 @@ protected void registerShuffleServers( remoteStorage, ShuffleDataDistributionType.NORMAL, maxConcurrencyPerPartitionToWrite, - stageAttemptNumber); + stageAttemptNumber, + null, + null, + null, + -1, + null); }); LOG.info( "Finish register shuffleId {} with {} ms", shuffleId, (System.currentTimeMillis() - start)); diff --git a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java index f6ddad7cf5..dbe40fd06f 100644 --- a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java +++ b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java @@ -589,7 +589,12 @@ public void registerShuffle( RemoteStorageInfo remoteStorage, ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, - int stageAttemptNumber) {} + int stageAttemptNumber, + String keyClassName, + String valueClassName, + String comparatorClassName, + int mergedBlockSize, + String mergeClassLoader) {} @Override public boolean sendCommit( @@ -684,5 +689,13 @@ public void unregisterShuffle(String appId, int shuffleId) {} @Override public void unregisterShuffle(String appId) {} + + @Override + public void startSortMerge( + Set serverInfos, + String appId, + int shuffleId, + int partitionId, + Roaring64NavigableMap expectedTaskIds) {} } } diff --git a/client/pom.xml b/client/pom.xml index 072a0fb811..c33153750e 100644 --- a/client/pom.xml +++ b/client/pom.xml @@ -76,5 +76,11 @@ + + org.apache.uniffle + rss-common + test-jar + test + diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java index 548f348f5a..121271e361 100644 --- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java +++ 
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java @@ -71,7 +71,12 @@ default void registerShuffle( remoteStorage, dataDistributionType, maxConcurrencyPerPartitionToWrite, - 0); + 0, + null, + null, + null, + -1, + null); } void registerShuffle( @@ -82,7 +87,12 @@ void registerShuffle( RemoteStorageInfo remoteStorage, ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, - int stageAttemptNumber); + int stageAttemptNumber, + String keyClassName, + String valueClassName, + String comparatorClassName, + int mergedBlockSize, + String mergeClassLoader); boolean sendCommit( Set shuffleServerInfoSet, String appId, int shuffleId, int numMaps); @@ -184,4 +194,11 @@ Roaring64NavigableMap getShuffleResultForMultiPart( void unregisterShuffle(String appId, int shuffleId); void unregisterShuffle(String appId); + + void startSortMerge( + Set serverInfos, + String appId, + int shuffleId, + int partitionId, + Roaring64NavigableMap expectedTaskIds); } diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index 2178f365b4..6f2860c11f 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -64,6 +64,7 @@ import org.apache.uniffle.client.request.RssReportShuffleResultRequest; import org.apache.uniffle.client.request.RssSendCommitRequest; import org.apache.uniffle.client.request.RssSendShuffleDataRequest; +import org.apache.uniffle.client.request.RssStartSortMergeRequest; import org.apache.uniffle.client.request.RssUnregisterShuffleByAppIdRequest; import org.apache.uniffle.client.request.RssUnregisterShuffleRequest; import org.apache.uniffle.client.response.ClientResponse; @@ -75,6 +76,7 @@ import org.apache.uniffle.client.response.RssReportShuffleResultResponse; import org.apache.uniffle.client.response.RssSendCommitResponse; import org.apache.uniffle.client.response.RssSendShuffleDataResponse; +import org.apache.uniffle.client.response.RssStartSortMergeResponse; import org.apache.uniffle.client.response.RssUnregisterShuffleByAppIdResponse; import org.apache.uniffle.client.response.RssUnregisterShuffleResponse; import org.apache.uniffle.client.response.SendShuffleDataResult; @@ -561,7 +563,12 @@ public void registerShuffle( RemoteStorageInfo remoteStorage, ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, - int stageAttemptNumber) { + int stageAttemptNumber, + String keyClassName, + String valueClassName, + String comparatorClassName, + int mergedBlockSize, + String mergeClassLoader) { String user = null; try { user = UserGroupInformation.getCurrentUser().getShortUserName(); @@ -578,7 +585,12 @@ public void registerShuffle( user, dataDistributionType, maxConcurrencyPerPartitionToWrite, - stageAttemptNumber); + stageAttemptNumber, + keyClassName, + valueClassName, + comparatorClassName, + mergedBlockSize, + mergeClassLoader); RssRegisterShuffleResponse response = getShuffleServerClient(shuffleServerInfo).registerShuffle(request); @@ -1069,6 +1081,56 @@ public void unregisterShuffle(String appId) { } } + @Override + public void startSortMerge( + Set serverInfos, + String appId, + int shuffleId, + int partitionId, + Roaring64NavigableMap expectedBlockIds) { + RssStartSortMergeRequest request = + new RssStartSortMergeRequest(appId, shuffleId, partitionId, 
expectedBlockIds); + boolean atLeastOneSuccessful = false; + for (ShuffleServerInfo ssi : serverInfos) { + RssStartSortMergeResponse response = getShuffleServerClient(ssi).startSortMerge(request); + if (response.getStatusCode() == StatusCode.SUCCESS) { + atLeastOneSuccessful = true; + LOG.info( + "Report unique blocks to " + + ssi + + " for appId[" + + appId + + "], shuffleId[" + + shuffleId + + "], partitionId[" + + partitionId + + "] successfully"); + } else { + LOG.warn( + "Report unique blocks to " + + ssi + + " for appId[" + + appId + + "], shuffleId[" + + shuffleId + + "], partitionId[" + + partitionId + + "] failed with " + + response.getStatusCode()); + } + } + if (!atLeastOneSuccessful) { + throw new RssFetchFailedException( + "Report unique blocks failed for appId[" + + appId + + "], shuffleId[" + + shuffleId + + "], partitionId[" + + partitionId + + "]"); + } + } + private void throwExceptionIfNecessary(ClientResponse response, String errorMsg) { if (response != null && response.getStatusCode() != StatusCode.SUCCESS) { LOG.error(errorMsg);
diff --git a/client/src/main/java/org/apache/uniffle/client/record/Record.java b/client/src/main/java/org/apache/uniffle/client/record/Record.java new file mode 100644 index 0000000000..c6eda51669 --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/record/Record.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record; + +import com.google.common.base.Objects; + +public class Record { + + private K key; + private V value; + + private Record(K key, V value) { + this.key = key; + this.value = value; + } + + public static Record create(K key, V value) { + return new Record(key, value); + } + + public K getKey() { + return key; + } + + public V getValue() { + return value; + } + + public void setValue(V value) { + this.value = value; + } + + @Override + public String toString() { + return "Record{" + "key=" + key + ", value=" + value + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Record record = (Record) o; + return Objects.equal(key, record.key) && Objects.equal(value, record.value); + } + + @Override + public int hashCode() { + return Objects.hashCode(key, value); + } +}
diff --git a/client/src/main/java/org/apache/uniffle/client/record/RecordBlob.java b/client/src/main/java/org/apache/uniffle/client/record/RecordBlob.java new file mode 100644 index 0000000000..3a5e7c9e4c --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/record/RecordBlob.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.uniffle.client.record.writer.Combiner; +import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.records.RecordsWriter; + +/* + * RecordBlob is used to store records. The records are stored in the form + * of a LinkedHashMap, and the blob is mainly used for combining. + * */ +public class RecordBlob implements RecordCollection { + + private final int partitionId; + + private int size = 0; + // We cannot determine the type of the record value statically. If map-side combine is + // enabled, the type of the value is C; otherwise it is V. + private LinkedHashMap>> records = new LinkedHashMap<>(); + private List> result = new ArrayList<>(); + + public RecordBlob(int partitionId) { + this.partitionId = partitionId; + } + + public void addRecords(RecordBuffer recordBuffer) { + List> recordList = recordBuffer.getRecords(); + for (Record record : recordList) { + K key = record.getKey(); + if (!records.containsKey(key)) { + records.put(key, new ArrayList<>()); + } + this.records.get(key).add(record); + this.size++; + } + } + + public void addRecord(K key, V value) { + if (!records.containsKey(key)) { + records.put(key, new ArrayList<>()); + } + this.records.get(key).add(Record.create(key, value)); + this.size++; + } + + public void combine(Combiner combiner, boolean isMapCombined) { + if (combiner == null) { + throw new RssException("combiner is not set"); + } + if (isMapCombined) { + this.result = combiner.combineCombiners(records.entrySet().iterator()); + } else { + this.result = combiner.combineValues(records.entrySet().iterator()); + } + records.clear(); + } + + public void serialize(RecordsWriter writer) throws IOException { + for (Record record : result) { + writer.append(record.getKey(), record.getValue()); + } + } + + public void clear() { + this.size = 0; + this.records.clear(); + this.result.clear(); + } + + public int getPartitionId() { + return partitionId; + } + + @Override + public int size() { + return this.size; + } + + @VisibleForTesting + public LinkedHashMap>> getRecords() { + return records; + } + + public List> getResult() { + return result; + } +}
diff --git a/client/src/main/java/org/apache/uniffle/client/record/RecordBuffer.java b/client/src/main/java/org/apache/uniffle/client/record/RecordBuffer.java new file mode 100644 index 0000000000..02032fcf65 --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/record/RecordBuffer.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.records.RecordsWriter; + +/* + * RecordBuffer is used to store records. The records are stored in the form of a List. + * It provides fast indexed access to records and supports sorting. + * */ +public class RecordBuffer implements RecordCollection { + + private final int partitionId; + + private int size = 0; + private List> records = new ArrayList<>(); + + public RecordBuffer(int partitionId) { + this.partitionId = partitionId; + } + + public void addRecord(K key, V value) { + Record record = Record.create(key, value); + this.records.add(record); + this.size++; + } + + public void addRecord(Record record) { + this.records.add(record); + this.size++; + } + + public void addRecords(List> records) { + this.records.addAll(records); + this.size += records.size(); + } + + public List> getRecords() { + return records; + } + + public void sort(Comparator comparator) { + if (comparator == null) { + throw new RssException("comparator is not set"); + } + this.records.sort( + new Comparator() { + @Override + public int compare(Record o1, Record o2) { + return comparator.compare(o1.getKey(), o2.getKey()); + } + }); + } + + public void serialize(RecordsWriter writer) throws IOException { + for (Record record : records) { + writer.append(record.getKey(), record.getValue()); + } + } + + public void clear() { + this.size = 0; + this.records.clear(); + } + + public int getPartitionId() { + return partitionId; + } + + @Override + public int size() { + return this.size; + } + + public K getKey(int index) { + return this.records.get(index).getKey(); + } + + public V getValue(int index) { + return this.records.get(index).getValue(); + } + + public K getLastKey() { + return this.records.get(this.records.size() - 1).getKey(); + } + + public K getFirstKey() { + return this.records.get(0).getKey(); + } +}
diff --git a/client/src/main/java/org/apache/uniffle/client/record/RecordCollection.java b/client/src/main/java/org/apache/uniffle/client/record/RecordCollection.java new file mode 100644 index 0000000000..e6c6f4f630 --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/record/RecordCollection.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record; + +import java.io.IOException; + +import org.apache.uniffle.common.records.RecordsWriter; + +public interface RecordCollection { + + void addRecord(K key, V value); + + void serialize(RecordsWriter writer) throws IOException; + + int size(); +} diff --git a/client/src/main/java/org/apache/uniffle/client/record/metrics/MetricsReporter.java b/client/src/main/java/org/apache/uniffle/client/record/metrics/MetricsReporter.java new file mode 100644 index 0000000000..bd63e39ee4 --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/record/metrics/MetricsReporter.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record.metrics; + +public interface MetricsReporter { + void incRecordsRead(long v); +} diff --git a/client/src/main/java/org/apache/uniffle/client/record/reader/BufferedSegment.java b/client/src/main/java/org/apache/uniffle/client/record/reader/BufferedSegment.java new file mode 100644 index 0000000000..d7f2afe8a1 --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/record/reader/BufferedSegment.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.client.record.reader; + +import java.io.IOException; + +import org.apache.uniffle.client.record.RecordBuffer; +import org.apache.uniffle.common.merger.Segment; + +public class BufferedSegment extends Segment { + + private RecordBuffer recordBuffer; + private int index = -1; + + public BufferedSegment(RecordBuffer recordBuffer) { + super(recordBuffer.getPartitionId()); + this.recordBuffer = recordBuffer; + } + + @Override + public boolean next() throws IOException { + boolean hasNext = index < this.recordBuffer.size() - 1; + if (hasNext) { + index++; + } + return hasNext; + } + + @Override + public Object getCurrentKey() { + return this.recordBuffer.getKey(index); + } + + @Override + public Object getCurrentValue() { + return this.recordBuffer.getValue(index); + } + + @Override + public void close() throws IOException { + if (recordBuffer != null) { + this.recordBuffer.clear(); + this.recordBuffer = null; + } + } +} diff --git a/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java b/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java new file mode 100644 index 0000000000..606e97ac48 --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record.reader; + +import java.io.IOException; + +public abstract class KeyValueReader { + + public abstract boolean next() throws IOException; + + public abstract K getCurrentKey() throws IOException; + + public abstract V getCurrentValue() throws IOException; +} diff --git a/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValuesReader.java b/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValuesReader.java new file mode 100644 index 0000000000..2a8db0faef --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValuesReader.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record.reader; + +import java.io.IOException; + +public abstract class KeyValuesReader { + + public abstract boolean next() throws IOException; + + public abstract K getCurrentKey() throws IOException; + + public abstract Iterable getCurrentValues() throws IOException; +} diff --git a/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java b/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java new file mode 100644 index 0000000000..30290a08f7 --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java @@ -0,0 +1,616 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record.reader; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.hadoop.io.DataInputBuffer; +import org.apache.hadoop.io.RawComparator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.client.api.ShuffleServerClient; +import org.apache.uniffle.client.factory.ShuffleServerClientFactory; +import org.apache.uniffle.client.record.Record; +import org.apache.uniffle.client.record.RecordBlob; +import org.apache.uniffle.client.record.RecordBuffer; +import org.apache.uniffle.client.record.metrics.MetricsReporter; +import org.apache.uniffle.client.record.writer.Combiner; +import org.apache.uniffle.client.request.RssGetSortedShuffleDataRequest; +import org.apache.uniffle.client.response.RssGetSortedShuffleDataResponse; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.merger.MergeState; +import org.apache.uniffle.common.merger.Merger; +import org.apache.uniffle.common.records.RecordsReader; +import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.Serializer; +import org.apache.uniffle.common.serializer.SerializerFactory; +import org.apache.uniffle.common.serializer.SerializerInstance; +import org.apache.uniffle.common.serializer.writable.ComparativeOutputBuffer; +import org.apache.uniffle.common.util.JavaUtils; + +import static 
org.apache.uniffle.client.util.RssClientConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE; +import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_MERGE_FETCH_INIT_SLEEP_MS; +import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_MERGE_FETCH_MAX_SLEEP_MS; +import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_MERGE_READER_MAX_BUFFER; +import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_MERGE_READER_MAX_RECORDS_PER_BUFFER; + +public class RMRecordsReader { + + private static final Logger LOG = LoggerFactory.getLogger(RMRecordsReader.class); + + private String appId; + private final int shuffleId; + private final Set partitionIds; + private final RssConf rssConf; + private final Class keyClass; + private final Class valueClass; + private final Comparator comparator; + private boolean raw; + private final Combiner combiner; + private boolean isMapCombine; + private final MetricsReporter metrics; + private SerializerInstance serializerInstance; + + private final long initFetchSleepTime; + private final long maxFetchSleepTime; + private final int maxBufferPerPartition; + private final int maxRecordsNumPerBuffer; + + private Map> shuffleServerInfoMap; + private volatile boolean stop = false; + private volatile String errorMessage = null; + + private Map> combineBuffers = JavaUtils.newConcurrentMap(); + private Map> mergeBuffers = JavaUtils.newConcurrentMap(); + private Queue results; + + public RMRecordsReader( + String appId, + int shuffleId, + Set partitionIds, + Map> shuffleServerInfoMap, + RssConf rssConf, + Class keyClass, + Class valueClass, + Comparator comparator, + boolean raw, + Combiner combiner, + boolean isMapCombine, + MetricsReporter metrics) { + this.appId = appId; + this.shuffleId = shuffleId; + this.partitionIds = partitionIds; + this.shuffleServerInfoMap = shuffleServerInfoMap; + this.rssConf = rssConf; + this.keyClass = keyClass; + this.valueClass = valueClass; + this.raw = raw; + if (raw && comparator == null) { + throw new RssException("RawComparator must be set!"); + } + this.comparator = + comparator != null + ? comparator + : new Comparator() { + @Override + public int compare(K o1, K o2) { + int h1 = (o1 == null) ? 0 : o1.hashCode(); + int h2 = (o2 == null) ? 0 : o2.hashCode(); + return h1 < h2 ? -1 : h1 == h2 ? 
0 : 1; + } + }; + this.combiner = combiner; + this.isMapCombine = isMapCombine; + this.metrics = metrics; + if (this.raw) { + SerializerFactory factory = new SerializerFactory(rssConf); + Serializer serializer = factory.getSerializer(keyClass); + assert factory.getSerializer(valueClass).getClass().equals(serializer.getClass()); + this.serializerInstance = serializer.newInstance(); + } + + this.initFetchSleepTime = rssConf.get(RSS_CLIENT_REMOTE_MERGE_FETCH_INIT_SLEEP_MS); + this.maxFetchSleepTime = rssConf.get(RSS_CLIENT_REMOTE_MERGE_FETCH_MAX_SLEEP_MS); + int maxBuffer = rssConf.get(RSS_CLIENT_REMOTE_MERGE_READER_MAX_BUFFER); + this.maxBufferPerPartition = Math.max(1, maxBuffer / partitionIds.size()); + this.maxRecordsNumPerBuffer = + rssConf.get(RSS_CLIENT_REMOTE_MERGE_READER_MAX_RECORDS_PER_BUFFER); + this.results = new Queue(maxBufferPerPartition * maxRecordsNumPerBuffer * partitionIds.size()); + LOG.info("RMRecordsReader constructed for partitions {}", partitionIds); + } + + public void start() { + for (int partitionId : partitionIds) { + mergeBuffers.put(partitionId, new Queue(maxBufferPerPartition)); + if (this.combiner != null) { + combineBuffers.put(partitionId, new Queue(maxBufferPerPartition)); + } + RecordsFetcher fetcher = new RecordsFetcher(partitionId); + fetcher.start(); + if (this.combiner != null) { + RecordsCombiner combineThread = new RecordsCombiner(partitionId); + combineThread.start(); + } + } + + RecordsMerger recordMerger = new RecordsMerger(); + recordMerger.start(); + } + + public void close() { + errorMessage = null; + stop = true; + for (Queue buffer : mergeBuffers.values()) { + buffer.clear(); + } + mergeBuffers.clear(); + if (combiner != null) { + for (Queue buffer : combineBuffers.values()) { + buffer.clear(); + } + combineBuffers.clear(); + } + if (results != null) { + this.results.clear(); + this.results = null; + } + } + + private boolean isSameKey(Object k1, Object k2) { + if (raw) { + ComparativeOutputBuffer buffer1 = (ComparativeOutputBuffer) k1; + ComparativeOutputBuffer buffer2 = (ComparativeOutputBuffer) k2; + return ((RawComparator) this.comparator) + .compare( + buffer1.getData(), + 0, + buffer1.getLength(), + buffer2.getData(), + 0, + buffer2.getLength()) + == 0; + } else { + return this.comparator.compare(k1, k2) == 0; + } + } + + public KeyValueReader rawKeyValueReader() { + if (!raw) { + throw new RssException("rawKeyValueReader is not supported!"); + } + return new KeyValueReader() { + + private Record curr = null; + + @Override + public boolean next() throws IOException { + try { + curr = results.take(); + return curr != null; + } catch (InterruptedException e) { + throw new IOException(e); + } + } + + @Override + public ComparativeOutputBuffer getCurrentKey() throws IOException { + return curr.getKey(); + } + + @Override + public ComparativeOutputBuffer getCurrentValue() throws IOException { + return curr.getValue(); + } + }; + } + + public KeyValueReader keyValueReader() { + return new KeyValueReader() { + + private Record curr = null; + + @Override + public boolean next() throws IOException { + try { + curr = results.take(); + return curr != null; + } catch (InterruptedException e) { + throw new IOException(e); + } + } + + @Override + public K getCurrentKey() throws IOException { + if (raw) { + ComparativeOutputBuffer keyBuffer = (ComparativeOutputBuffer) curr.getKey(); + DataInputBuffer keyInputBuffer = new DataInputBuffer(); + keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength()); + return 
serializerInstance.deserialize(keyInputBuffer, keyClass); + } else { + return curr.getKey(); + } + } + + @Override + public C getCurrentValue() throws IOException { + if (raw) { + ComparativeOutputBuffer valueBuffer = (ComparativeOutputBuffer) curr.getValue(); + DataInputBuffer valueInputBuffer = new DataInputBuffer(); + valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength()); + return serializerInstance.deserialize(valueInputBuffer, valueClass); + } else { + return curr.getValue(); + } + } + }; + } + + public KeyValuesReader keyValuesReader() { + return new KeyValuesReader() { + + private Record start = null; + + @Override + public boolean next() throws IOException { + try { + if (start == null) { + start = results.take(); + return start != null; + } else { + return true; + } + } catch (InterruptedException e) { + throw new IOException(e); + } + } + + @Override + public K getCurrentKey() throws IOException { + if (raw) { + ComparativeOutputBuffer keyBuffer = (ComparativeOutputBuffer) start.getKey(); + DataInputBuffer keyInputBuffer = new DataInputBuffer(); + keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength()); + return serializerInstance.deserialize(keyInputBuffer, keyClass); + } else { + return start.getKey(); + } + } + + @Override + public Iterable getCurrentValues() throws IOException { + return new Iterable() { + @Override + public Iterator iterator() { + + return new Iterator() { + + Record curr = start; + + @Override + public boolean hasNext() { + if (curr != null && isSameKey(curr.getKey(), start.getKey())) { + return true; + } else { + start = curr; + return false; + } + } + + @Override + public C next() { + try { + C ret; + if (raw) { + ComparativeOutputBuffer valueBuffer = (ComparativeOutputBuffer) curr.getValue(); + DataInputBuffer valueInputBuffer = new DataInputBuffer(); + valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength()); + ret = serializerInstance.deserialize(valueInputBuffer, valueClass); + } else { + ret = curr.getValue(); + } + curr = results.take(); + return ret; + } catch (InterruptedException | IOException e) { + throw new RssException(e); + } + } + }; + } + }; + } + }; + } + + class Queue { + + private LinkedBlockingQueue queue; + private volatile boolean producerDone = false; + + Queue(int maxBufferPerPartition) { + this.queue = new LinkedBlockingQueue(maxBufferPerPartition); + } + + public void setProducerDone(boolean producerDone) { + this.producerDone = producerDone; + } + + public void put(E recordBuffer) throws InterruptedException { + this.queue.put(recordBuffer); + } + + // Block until data arrives or the producer completes the work. 
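+ // The queue is polled with a short timeout rather than blocked on indefinitely, so the
+ // consumer can also observe the stop flag and any error reported by a producer thread.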
+ // If null is returned, it means that all data has been processed. + public E take() throws InterruptedException { + while (!producerDone && !stop) { + E e = this.queue.poll(100, TimeUnit.MILLISECONDS); + if (e != null) { + return e; + } + } + if (errorMessage != null) { + throw new RssException("RMRecordsReader fetch record failed, caused by " + errorMessage); + } + return this.queue.poll(100, TimeUnit.MILLISECONDS); + } + + public void clear() { + this.queue.clear(); + this.producerDone = false; + } + } + + class RecordsFetcher extends Thread { + + private int partitionId; + private long sleepTime; + private long blockId = 1; // Merged blockIds count from 1 + private RecordBuffer recordBuffer; + private Queue nextQueue; + private List serverInfos; + private ShuffleServerClient client; + private int choose; + private String fetchError; + + RecordsFetcher(int partitionId) { + this.partitionId = partitionId; + this.sleepTime = initFetchSleepTime; + this.recordBuffer = new RecordBuffer(partitionId); + this.nextQueue = + combiner == null ? mergeBuffers.get(partitionId) : combineBuffers.get(partitionId); + this.serverInfos = shuffleServerInfoMap.get(partitionId); + this.choose = serverInfos.size() - 1; + this.client = createShuffleServerClient(serverInfos.get(choose)); + setName("RecordsFetcher-" + partitionId); + } + + private void nextShuffleServerInfo() { + if (this.choose <= 0) { + throw new RssException("Fetch sorted record failed, last error message is " + fetchError); + } + choose--; + this.client = createShuffleServerClient(serverInfos.get(choose)); + } + + @Override + public void run() { + while (!stop) { + try { + RssGetSortedShuffleDataRequest request = + new RssGetSortedShuffleDataRequest(appId, shuffleId, partitionId, blockId); + RssGetSortedShuffleDataResponse response = client.getSortedShuffleData(request); + if (response.getStatusCode() != StatusCode.SUCCESS + || response.getMergeState() == MergeState.INTERNAL_ERROR.code()) { + fetchError = response.getMessage(); + nextShuffleServerInfo(); + break; + } else if (response.getMergeState() == MergeState.INITED.code()) { + fetchError = "Remote merge should be started!"; + nextShuffleServerInfo(); + break; + } + if (response.getMergeState() == MergeState.MERGING.code() + && response.getNextBlockId() == -1) { + // All merged data has been read, but there may be data that has not yet been merged, + // so wait until merging is done. + LOG.info("RMRecordsFetcher will sleep {} ms", sleepTime); + Thread.sleep(this.sleepTime); + this.sleepTime = Math.min(this.sleepTime * 2, maxFetchSleepTime); + } else if (response.getMergeState() == MergeState.DONE.code() + && response.getNextBlockId() == -1) { + // All data has been read. Send the last records. + if (recordBuffer.size() > 0) { + nextQueue.put(recordBuffer); + } + nextQueue.setProducerDone(true); + break; + } else if (response.getMergeState() == MergeState.DONE.code() + || response.getMergeState() == MergeState.MERGING.code()) { + this.sleepTime = initFetchSleepTime; + ByteBuffer byteBuffer = response.getData(); + blockId = response.getNextBlockId(); + // Fetching blocks and parsing blocks is a synchronous process here. If the two steps + // were split into two different threads, they would run asynchronously; although that + // seems to save time, it actually consumes more memory. 
+ RecordsReader reader = + new RecordsReader( + rssConf, + PartialInputStream.newInputStream(byteBuffer), + keyClass, + valueClass, + raw); + while (reader.next()) { + if (metrics != null) { + metrics.incRecordsRead(1); + } + if (recordBuffer.size() >= maxRecordsNumPerBuffer) { + nextQueue.put(recordBuffer); + recordBuffer = new RecordBuffer(partitionId); + } + recordBuffer.addRecord(reader.getCurrentKey(), reader.getCurrentValue()); + } + } else { + fetchError = "Received wrong offset from server, offset is " + response.getNextBlockId(); + nextShuffleServerInfo(); + break; + } + } catch (Exception e) { + errorMessage = e.getMessage(); + stop = true; + LOG.info("Found exception when fetching sorted records, caused by ", e); + } + } + } + } + + class RecordsCombiner extends Thread { + + private int partitionId; + // The RecordBuffer has a capacity limit, so records for the same key may be + // distributed across different RecordBuffers. We therefore need a cached buffer + // that holds the output of the last combine. + private RecordBuffer cached; + private Queue nextQueue; + + RecordsCombiner(int partitionId) { + this.partitionId = partitionId; + this.cached = new RecordBuffer(partitionId); + this.nextQueue = mergeBuffers.get(partitionId); + setName("RecordsCombiner-" + partitionId); + } + + @Override + public void run() { + while (!stop) { + try { + // 1 try to get a RecordBuffer from the RecordsFetcher + RecordBuffer current = combineBuffers.get(partitionId).take(); + // a null current means that all upstream data has been read + if (current == null) { + if (cached.size() > 0) { + sendCachedBuffer(cached); + } + nextQueue.setProducerDone(true); + break; + } else { + // 2 If the last key of cached is not the same as the first key of current, + // we can send cached downstream directly. + if (cached.size() > 0 && !isSameKey(cached.getLastKey(), current.getFirstKey())) { + sendCachedBuffer(cached); + cached = new RecordBuffer(partitionId); + } + + // 3 combine the current buffer, then cache it. This way, we can handle the special + // case where the next record buffer has the same key as the current one. + RecordBlob recordBlob = new RecordBlob(partitionId); + recordBlob.addRecords(current); + recordBlob.combine(combiner, isMapCombine); + for (Object record : recordBlob.getResult()) { + if (cached.size() >= maxRecordsNumPerBuffer + && !isSameKey(((Record) record).getKey(), cached.getLastKey())) { + sendCachedBuffer(cached); + cached = new RecordBuffer<>(partitionId); + } + cached.addRecord((Record) record); + } + } + } catch (InterruptedException e) { + throw new RssException(e); + } + } + } + + private void sendCachedBuffer(RecordBuffer cachedBuffer) throws InterruptedException { + // Multiple records with the same key may span different record buffers, and so far records + // have only been combined within a single record buffer. So before sending downstream, we + // should combine the cached buffer once more. 
+ RecordBlob recordBlob = new RecordBlob(partitionId); + recordBlob.addRecords(cachedBuffer); + recordBlob.combine(combiner, true); + RecordBuffer recordBuffer = new RecordBuffer<>(partitionId); + recordBuffer.addRecords(recordBlob.getResult()); + nextQueue.put(recordBuffer); + } + } + + class RecordsMerger extends Thread { + + RecordsMerger() { + setName("RecordsMerger"); + } + + @Override + public void run() { + try { + List segments = new ArrayList<>(); + for (int partitionId : partitionIds) { + RecordBuffer recordBuffer = mergeBuffers.get(partitionId).take(); + if (recordBuffer != null) { + BufferedSegment resolvedSegment = new BufferedSegment(recordBuffer); + segments.add(resolvedSegment); + } + } + Merger.MergeQueue mergeQueue = + new Merger.MergeQueue(rssConf, segments, keyClass, valueClass, comparator, raw); + mergeQueue.init(); + mergeQueue.setPopSegmentHook( + pid -> { + try { + RecordBuffer recordBuffer = mergeBuffers.get(pid).take(); + if (recordBuffer == null) { + return null; + } + return new BufferedSegment(recordBuffer); + } catch (InterruptedException ex) { + throw new RssException(ex); + } + }); + while (!stop && mergeQueue.next()) { + results.put(Record.create(mergeQueue.getCurrentKey(), mergeQueue.getCurrentValue())); + } + if (!stop) { + results.setProducerDone(true); + } + } catch (InterruptedException | IOException e) { + errorMessage = e.getMessage(); + stop = true; + } + } + } + + @VisibleForTesting + public ShuffleServerClient createShuffleServerClient(ShuffleServerInfo shuffleServerInfo) { + return ShuffleServerClientFactory.getInstance() + .getShuffleServerClient(RSS_CLIENT_TYPE_DEFAULT_VALUE, shuffleServerInfo, rssConf); + } +}
diff --git a/client/src/main/java/org/apache/uniffle/client/record/writer/Combiner.java b/client/src/main/java/org/apache/uniffle/client/record/writer/Combiner.java new file mode 100644 index 0000000000..80253c1122 --- /dev/null +++ b/client/src/main/java/org/apache/uniffle/client/record/writer/Combiner.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record.writer; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.apache.uniffle.client.record.Record; + +/* + * If shuffle requires sorting, the input records are sorted by key, so they can be combined sequentially. + * If shuffle does not require sorting, the input records are ordered according to the hash code of the key. + * Therefore, records with the same key may not be grouped together, so the data cannot be consumed + * sequentially, and a LinkedHashMap needs to be used. 
+ * */ +public abstract class Combiner { + + public abstract List combineValues(Iterator>> recordIterator); + + public abstract List combineCombiners( + Iterator>> recordIterator); +} diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/BufferedSegmentTest.java b/client/src/test/java/org/apache/uniffle/client/record/reader/BufferedSegmentTest.java new file mode 100644 index 0000000000..c4e32b5722 --- /dev/null +++ b/client/src/test/java/org/apache/uniffle/client/record/reader/BufferedSegmentTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record.reader; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.junit.jupiter.api.Test; + +import org.apache.uniffle.client.record.RecordBuffer; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.merger.Merger; +import org.apache.uniffle.common.merger.Segment; +import org.apache.uniffle.common.serializer.SerializerUtils; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class BufferedSegmentTest { + + private static final int RECORDS = 1000; + private static final int SEGMENTS = 4; + + @Test + public void testMergeResolvedSegmentWithHook() throws Exception { + RssConf rssConf = new RssConf(); + List segments = new ArrayList<>(); + Comparator comparator = new Text.Comparator(); + for (int i = 0; i < SEGMENTS; i++) { + if (i % 2 == 0) { + segments.add(genResolvedSegment(Text.class, IntWritable.class, i, i, 4, RECORDS / 2)); + } else { + segments.add(genResolvedSegment(Text.class, IntWritable.class, i, i, 4, RECORDS)); + } + } + Map newSegments = new HashMap<>(); + for (int i = 0; i < SEGMENTS; i++) { + if (i % 2 == 0) { + newSegments.put( + i, + genResolvedSegment(Text.class, IntWritable.class, i, i + RECORDS * 2, 4, RECORDS / 2)); + } else { + newSegments.put( + i, genResolvedSegment(Text.class, IntWritable.class, i, i + RECORDS * 4, 4, RECORDS)); + } + } + Merger.MergeQueue mergeQueue = + new Merger.MergeQueue(rssConf, segments, Text.class, IntWritable.class, comparator, false); + mergeQueue.init(); + mergeQueue.setPopSegmentHook( + id -> { + Segment newSegment = newSegments.get(id); + if (newSegment != null) { + newSegments.remove(id); + } + return newSegment; + }); + for (int i = 0; i < RECORDS * 8; i++) { + if ((i >= 4 * RECORDS) && (i % 2 == 0)) { + continue; + } + mergeQueue.next(); + assertEquals(SerializerUtils.genData(Text.class, i), mergeQueue.getCurrentKey()); + assertEquals(SerializerUtils.genData(IntWritable.class, i), 
mergeQueue.getCurrentValue()); + } + assertFalse(mergeQueue.next()); + } + + private static BufferedSegment genResolvedSegment( + Class keyClass, Class valueClass, int pid, int start, int interval, int length) { + RecordBuffer buffer = new RecordBuffer(pid); + for (int i = 0; i < length; i++) { + buffer.addRecord( + SerializerUtils.genData(keyClass, start + i * interval), + SerializerUtils.genData(valueClass, start + i * interval)); + } + return new BufferedSegment(buffer); + } +} diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleServerClient.java b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleServerClient.java new file mode 100644 index 0000000000..93da7c4c9d --- /dev/null +++ b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleServerClient.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.record.reader; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.roaringbitmap.longlong.Roaring64NavigableMap; + +import org.apache.uniffle.client.api.ShuffleServerClient; +import org.apache.uniffle.client.request.RssAppHeartBeatRequest; +import org.apache.uniffle.client.request.RssFinishShuffleRequest; +import org.apache.uniffle.client.request.RssGetInMemoryShuffleDataRequest; +import org.apache.uniffle.client.request.RssGetShuffleDataRequest; +import org.apache.uniffle.client.request.RssGetShuffleIndexRequest; +import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest; +import org.apache.uniffle.client.request.RssGetShuffleResultRequest; +import org.apache.uniffle.client.request.RssGetSortedShuffleDataRequest; +import org.apache.uniffle.client.request.RssRegisterShuffleRequest; +import org.apache.uniffle.client.request.RssReportShuffleResultRequest; +import org.apache.uniffle.client.request.RssSendCommitRequest; +import org.apache.uniffle.client.request.RssSendShuffleDataRequest; +import org.apache.uniffle.client.request.RssStartSortMergeRequest; +import org.apache.uniffle.client.request.RssUnregisterShuffleByAppIdRequest; +import org.apache.uniffle.client.request.RssUnregisterShuffleRequest; +import org.apache.uniffle.client.response.RssAppHeartBeatResponse; +import org.apache.uniffle.client.response.RssFinishShuffleResponse; +import org.apache.uniffle.client.response.RssGetInMemoryShuffleDataResponse; +import org.apache.uniffle.client.response.RssGetShuffleDataResponse; +import org.apache.uniffle.client.response.RssGetShuffleIndexResponse; +import org.apache.uniffle.client.response.RssGetShuffleResultResponse; +import 
org.apache.uniffle.client.response.RssGetSortedShuffleDataResponse;
+import org.apache.uniffle.client.response.RssRegisterShuffleResponse;
+import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
+import org.apache.uniffle.client.response.RssSendCommitResponse;
+import org.apache.uniffle.client.response.RssSendShuffleDataResponse;
+import org.apache.uniffle.client.response.RssStartSortMergeResponse;
+import org.apache.uniffle.client.response.RssUnregisterShuffleByAppIdResponse;
+import org.apache.uniffle.client.response.RssUnregisterShuffleResponse;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.merger.MergeState;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.RssUtils;
+
+public class MockedShuffleServerClient implements ShuffleServerClient {
+
+  private Map<Integer, List<ByteBuffer>> shuffleData;
+  private Map<Integer, Integer> indexes;
+  private long[] blockIds;
+
+  public MockedShuffleServerClient(int[] partitionIds, ByteBuffer[][] buffers, long[] blockIds) {
+    if (partitionIds.length != buffers.length) {
+      throw new RssException("the number of partition ids does not match the number of buffers");
+    }
+    this.shuffleData = new HashMap<>();
+    for (int i = 0; i < partitionIds.length; i++) {
+      int partition = partitionIds[i];
+      shuffleData.put(partition, new ArrayList<>());
+      for (ByteBuffer byteBuffer : buffers[i]) {
+        shuffleData.get(partition).add(byteBuffer);
+      }
+    }
+    this.indexes = new HashMap<>();
+    for (Integer pid : shuffleData.keySet()) {
+      indexes.put(pid, 0);
+    }
+    this.blockIds = blockIds;
+  }
+
+  // Serves the pre-loaded sorted buffers for a partition one call at a time; a null buffer
+  // signals that all data has been fetched.
+  @Override
+  public RssGetSortedShuffleDataResponse getSortedShuffleData(
+      RssGetSortedShuffleDataRequest request) {
+    int partitionId = request.getPartitionId();
+    if (!shuffleData.containsKey(partitionId)) {
+      throw new RssException("partition id does not exist");
+    }
+    RssGetSortedShuffleDataResponse response;
+    int index = indexes.get(partitionId);
+    if (index < shuffleData.get(partitionId).size()) {
+      // The offset is ignored in the mock client, so set it to the unused value 10000.
+      response =
+          new RssGetSortedShuffleDataResponse(
+              StatusCode.SUCCESS,
+              shuffleData.get(partitionId).get(index),
+              10000,
+              MergeState.DONE.code());
+    } else {
+      response =
+          new RssGetSortedShuffleDataResponse(StatusCode.SUCCESS, null, -1, MergeState.DONE.code());
+    }
+    indexes.put(partitionId, index + 1);
+    return response;
+  }
+
+  @Override
+  public RssUnregisterShuffleResponse unregisterShuffle(RssUnregisterShuffleRequest request) {
+    return null;
+  }
+
+  @Override
+  public RssRegisterShuffleResponse registerShuffle(RssRegisterShuffleRequest request) {
+    return null;
+  }
+
+  @Override
+  public RssUnregisterShuffleByAppIdResponse unregisterShuffleByAppId(
+      RssUnregisterShuffleByAppIdRequest request) {
+    return null;
+  }
+
+  @Override
+  public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest request) {
+    return null;
+  }
+
+  @Override
+  public RssSendCommitResponse sendCommit(RssSendCommitRequest request) {
+    return null;
+  }
+
+  @Override
+  public RssAppHeartBeatResponse sendHeartBeat(RssAppHeartBeatRequest request) {
+    return null;
+  }
+
+  @Override
+  public RssFinishShuffleResponse finishShuffle(RssFinishShuffleRequest request) {
+    return null;
+  }
+
+  @Override
+  public RssReportShuffleResultResponse reportShuffleResult(RssReportShuffleResultRequest request) {
+    return null;
+  }
+
+  @Override
+  public RssGetShuffleResultResponse getShuffleResult(RssGetShuffleResultRequest request) {
+    try {
+      Roaring64NavigableMap bitMap = Roaring64NavigableMap.bitmapOf();
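+      // The mocked shuffle result simply reports every block id handed to the constructor.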
for (long blockId : blockIds) { + bitMap.add(blockId); + } + return new RssGetShuffleResultResponse(StatusCode.SUCCESS, RssUtils.serializeBitMap(bitMap)); + } catch (IOException e) { + throw new RssException(e); + } + } + + @Override + public RssGetShuffleResultResponse getShuffleResultForMultiPart( + RssGetShuffleResultForMultiPartRequest request) { + return null; + } + + @Override + public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest request) { + return null; + } + + @Override + public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request) { + return null; + } + + @Override + public RssStartSortMergeResponse startSortMerge(RssStartSortMergeRequest request) { + return null; + } + + @Override + public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData( + RssGetInMemoryShuffleDataRequest request) { + return null; + } + + @Override + public void close() {} + + @Override + public String getClientInfo() { + return null; + } +} diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java new file mode 100644 index 0000000000..7d4cbf980b --- /dev/null +++ b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.client.record.reader; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; + +import org.roaringbitmap.longlong.Roaring64NavigableMap; + +import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking; +import org.apache.uniffle.client.api.ShuffleWriteClient; +import org.apache.uniffle.client.response.SendShuffleDataResult; +import org.apache.uniffle.common.PartitionRange; +import org.apache.uniffle.common.RemoteStorageInfo; +import org.apache.uniffle.common.ShuffleAssignmentsInfo; +import org.apache.uniffle.common.ShuffleBlockInfo; +import org.apache.uniffle.common.ShuffleDataDistributionType; +import org.apache.uniffle.common.ShuffleServerInfo; + +public class MockedShuffleWriteClient implements ShuffleWriteClient { + + private Map>> blockIds = new HashMap<>(); + + @Override + public SendShuffleDataResult sendShuffleData( + String appId, + List shuffleBlockInfoList, + Supplier needCancelRequest) { + return null; + } + + @Override + public void sendAppHeartbeat(String appId, long timeoutMs) {} + + @Override + public void registerApplicationInfo(String appId, long timeoutMs, String user) {} + + @Override + public void registerShuffle( + ShuffleServerInfo shuffleServerInfo, + String appId, + int shuffleId, + List partitionRanges, + RemoteStorageInfo remoteStorage, + ShuffleDataDistributionType dataDistributionType, + int maxConcurrencyPerPartitionToWrite, + int stageAttemptNumber, + String keyClassName, + String valueClassName, + String comparatorClassName, + int mergedBlockSize, + String mergeClassLoader) {} + + @Override + public boolean sendCommit( + Set shuffleServerInfoSet, String appId, int shuffleId, int numMaps) { + return false; + } + + @Override + public void registerCoordinators(String coordinators, long retryIntervalMs, int retryTimes) {} + + @Override + public Map fetchClientConf(int timeoutMs) { + return null; + } + + @Override + public RemoteStorageInfo fetchRemoteStorage(String appId) { + return null; + } + + @Override + public void reportShuffleResult( + Map>> serverToPartitionToBlockIds, + String appId, + int shuffleId, + long taskAttemptId, + int bitmapNum) { + this.blockIds.putIfAbsent(appId, new HashMap<>()); + this.blockIds.get(appId).putIfAbsent(shuffleId, new HashMap<>()); + + for (Map> partitionToBlockIds : serverToPartitionToBlockIds.values()) { + for (Map.Entry> entry : partitionToBlockIds.entrySet()) { + int partitionId = entry.getKey(); + this.blockIds + .get(appId) + .get(shuffleId) + .putIfAbsent(partitionId, Roaring64NavigableMap.bitmapOf()); + for (long blockId : entry.getValue()) { + this.blockIds.get(appId).get(shuffleId).get(partitionId).add(blockId); + } + } + } + } + + @Override + public ShuffleAssignmentsInfo getShuffleAssignments( + String appId, + int shuffleId, + int partitionNum, + int partitionNumPerRange, + Set requiredTags, + int assignmentShuffleServerNumber, + int estimateTaskConcurrency, + Set faultyServerIds, + int stageId, + int stageAttemptNumber, + boolean reassign, + long retryIntervalMs, + int retryTimes) { + return null; + } + + @Override + public ShuffleAssignmentsInfo getShuffleAssignments( + String appId, + int shuffleId, + int partitionNum, + int partitionNumPerRange, + Set requiredTags, + int assignmentShuffleServerNumber, + int estimateTaskConcurrency) { + return null; + } + + @Override + public void startSortMerge( + Set serverInfos, + String appId, + int shuffleId, + int partitionId, + 
Roaring64NavigableMap expectedTaskIds) {} + + @Override + public Roaring64NavigableMap getShuffleResult( + String clientType, + Set shuffleServerInfoSet, + String appId, + int shuffleId, + int partitionId) { + return this.blockIds.get(appId).get(shuffleId).get(partitionId); + } + + @Override + public Roaring64NavigableMap getShuffleResultForMultiPart( + String clientType, + Map> serverToPartitions, + String appId, + int shuffleId, + Set failedPartitions, + PartitionDataReplicaRequirementTracking replicaRequirementTracking) { + return null; + } + + @Override + public void close() {} + + @Override + public void unregisterShuffle(String appId, int shuffleId) {} + + @Override + public void unregisterShuffle(String appId) {} +} diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java b/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java new file mode 100644 index 0000000000..65d513ced1 --- /dev/null +++ b/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.client.record.reader; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; +import org.apache.hadoop.io.IntWritable; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import org.apache.uniffle.client.api.ShuffleServerClient; +import org.apache.uniffle.client.record.writer.Combiner; +import org.apache.uniffle.client.record.writer.SumByKeyCombiner; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.merger.Merger; +import org.apache.uniffle.common.merger.Segment; +import org.apache.uniffle.common.serializer.SerializerFactory; +import org.apache.uniffle.common.serializer.SerializerInstance; +import org.apache.uniffle.common.serializer.SerializerUtils; + +import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; + +public class RMRecordsReaderTest { + + private static final String APP_ID = "app1"; + private static final int SHUFFLE_ID = 0; + private static final int RECORDS_NUM = 1009; + + @Timeout(30) + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + public void testNormalReadWithoutCombine(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final boolean raw = classArray.length > 2 ? 
Boolean.parseBoolean(classArray[2]) : false; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final Comparator comparator = SerializerUtils.getComparator(keyClass); + final Combiner combiner = null; + final int partitionId = 0; + final RssConf rssConf = new RssConf(); + final List serverInfos = new ArrayList<>(); + serverInfos.add(new ShuffleServerInfo("dummy", -1)); + + // 2 construct reader + RMRecordsReader reader = + new RMRecordsReader( + APP_ID, + SHUFFLE_ID, + Sets.newHashSet(partitionId), + ImmutableMap.of(partitionId, serverInfos), + rssConf, + keyClass, + valueClass, + comparator, + raw, + combiner, + false, + null); + byte[] buffers = genSortedRecordBytes(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1); + ShuffleServerClient serverClient = + new MockedShuffleServerClient( + new int[] {partitionId}, new ByteBuffer[][] {{ByteBuffer.wrap(buffers)}}, null); + RMRecordsReader readerSpy = spy(reader); + doReturn(serverClient).when(readerSpy).createShuffleServerClient(any()); + + // 3 run reader and verify result + readerSpy.start(); + int index = 0; + KeyValueReader keyValueReader = readerSpy.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, index), keyValueReader.getCurrentValue()); + index++; + } + assertEquals(RECORDS_NUM, index); + } + + @Timeout(30) + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + public void testNormalReadWithCombine(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final boolean raw = classArray.length > 2 ? 
Boolean.parseBoolean(classArray[2]) : false; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final Comparator comparator = SerializerUtils.getComparator(keyClass); + SerializerFactory factory = new SerializerFactory(new RssConf()); + org.apache.uniffle.common.serializer.Serializer serializer = factory.getSerializer(keyClass); + SerializerInstance serializerInstance = serializer.newInstance(); + final Combiner combiner = new SumByKeyCombiner(raw, serializerInstance, keyClass, valueClass); + final int partitionId = 0; + final RssConf rssConf = new RssConf(); + final List serverInfos = new ArrayList<>(); + serverInfos.add(new ShuffleServerInfo("dummy", -1)); + + // 2 construct reader + List segments = new ArrayList<>(); + segments.add( + SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM)); + segments.add( + SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM)); + segments.add( + SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM)); + ByteArrayOutputStream output = new ByteArrayOutputStream(); + Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, false); + output.close(); + byte[] buffers = output.toByteArray(); + ShuffleServerClient serverClient = + new MockedShuffleServerClient( + new int[] {partitionId}, new ByteBuffer[][] {{ByteBuffer.wrap(buffers)}}, null); + RMRecordsReader reader = + new RMRecordsReader( + APP_ID, + SHUFFLE_ID, + Sets.newHashSet(partitionId), + ImmutableMap.of(partitionId, serverInfos), + rssConf, + keyClass, + valueClass, + comparator, + raw, + combiner, + false, + null); + RMRecordsReader readerSpy = spy(reader); + doReturn(serverClient).when(readerSpy).createShuffleServerClient(any()); + + // 3 run reader and verify result + readerSpy.start(); + int index = 0; + KeyValueReader keyValueReader = readerSpy.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + Object value = SerializerUtils.genData(valueClass, index); + Object newValue = value; + if (index % 2 == 0) { + if (value instanceof IntWritable) { + newValue = new IntWritable(((IntWritable) value).get() * 2); + } else { + newValue = (int) value * 2; + } + } + assertEquals(newValue, keyValueReader.getCurrentValue()); + index++; + } + assertEquals(RECORDS_NUM * 2, index); + } + + @Timeout(30) + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + public void testReadMulitPartitionWithoutCombine(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final boolean raw = classArray.length > 2 ? 
Boolean.parseBoolean(classArray[2]) : false; + final Comparator comparator = SerializerUtils.getComparator(keyClass); + final Combiner combiner = null; + final int partitionId = 0; + final RssConf rssConf = new RssConf(); + final List serverInfos = new ArrayList<>(); + serverInfos.add(new ShuffleServerInfo("dummy", -1)); + + // 2 construct reader + RMRecordsReader reader = + new RMRecordsReader( + APP_ID, + SHUFFLE_ID, + Sets.newHashSet(partitionId, partitionId + 1, partitionId + 2), + ImmutableMap.of( + partitionId, + serverInfos, + partitionId + 1, + serverInfos, + partitionId + 2, + serverInfos), + rssConf, + keyClass, + valueClass, + comparator, + raw, + combiner, + false, + null); + RMRecordsReader readerSpy = spy(reader); + ByteBuffer[][] buffers = new ByteBuffer[3][2]; + for (int i = 0; i < 3; i++) { + buffers[i][0] = + ByteBuffer.wrap( + genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1)); + buffers[i][1] = + ByteBuffer.wrap( + genSortedRecordBytes( + rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1)); + } + ShuffleServerClient serverClient = + new MockedShuffleServerClient( + new int[] {partitionId, partitionId + 1, partitionId + 2}, buffers, null); + doReturn(serverClient).when(readerSpy).createShuffleServerClient(any()); + + // 3 run reader and verify result + readerSpy.start(); + int index = 0; + KeyValueReader keyValueReader = readerSpy.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, index), keyValueReader.getCurrentValue()); + index++; + } + assertEquals(RECORDS_NUM * 6, index); + } + + @Timeout(30) + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + public void testReadMulitPartitionWithCombine(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final boolean raw = classArray.length > 2 ? 
Boolean.parseBoolean(classArray[2]) : false; + final Comparator comparator = SerializerUtils.getComparator(keyClass); + SerializerFactory factory = new SerializerFactory(new RssConf()); + org.apache.uniffle.common.serializer.Serializer serializer = factory.getSerializer(keyClass); + SerializerInstance serializerInstance = serializer.newInstance(); + final Combiner combiner = new SumByKeyCombiner(raw, serializerInstance, keyClass, valueClass); + final int partitionId = 0; + final RssConf rssConf = new RssConf(); + final List serverInfos = new ArrayList<>(); + serverInfos.add(new ShuffleServerInfo("dummy", -1)); + + // 2 construct reader + RMRecordsReader reader = + new RMRecordsReader( + APP_ID, + SHUFFLE_ID, + Sets.newHashSet(partitionId, partitionId + 1, partitionId + 2), + ImmutableMap.of( + partitionId, + serverInfos, + partitionId + 1, + serverInfos, + partitionId + 2, + serverInfos), + rssConf, + keyClass, + valueClass, + comparator, + raw, + combiner, + false, + null); + RMRecordsReader readerSpy = spy(reader); + ByteBuffer[][] buffers = new ByteBuffer[3][2]; + for (int i = 0; i < 3; i++) { + buffers[i][0] = + ByteBuffer.wrap( + genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 2)); + buffers[i][1] = + ByteBuffer.wrap( + genSortedRecordBytes( + rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 2)); + } + ShuffleServerClient serverClient = + new MockedShuffleServerClient( + new int[] {partitionId, partitionId + 1, partitionId + 2}, buffers, null); + doReturn(serverClient).when(readerSpy).createShuffleServerClient(any()); + + // 3 run reader and verify result + readerSpy.start(); + int index = 0; + KeyValueReader keyValueReader = readerSpy.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + assertEquals( + SerializerUtils.genData(valueClass, index * 2), keyValueReader.getCurrentValue()); + index++; + } + assertEquals(RECORDS_NUM * 6, index); + } +} diff --git a/client/src/test/java/org/apache/uniffle/client/record/writer/RecordCollectionTest.java b/client/src/test/java/org/apache/uniffle/client/record/writer/RecordCollectionTest.java new file mode 100644 index 0000000000..b5f0b5bd7b --- /dev/null +++ b/client/src/test/java/org/apache/uniffle/client/record/writer/RecordCollectionTest.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.client.record.writer; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import org.apache.uniffle.client.record.RecordBlob; +import org.apache.uniffle.client.record.RecordBuffer; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.records.RecordsReader; +import org.apache.uniffle.common.records.RecordsWriter; +import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.SerializerUtils; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RecordCollectionTest { + + private static final int RECORDS = 1009; + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable", + }) + public void testSortAndSerializeRecords(String classes) throws Exception { + // 1 Parse arguments + String[] classArray = classes.split(","); + Class keyClass = SerializerUtils.getClassByName(classArray[0]); + Class valueClass = SerializerUtils.getClassByName(classArray[1]); + + // 2 add record + RecordBuffer recordBuffer = new RecordBuffer(0); + List indexes = new ArrayList<>(); + for (int i = 0; i < RECORDS; i++) { + indexes.add(i); + } + Collections.shuffle(indexes); + for (Integer index : indexes) { + recordBuffer.addRecord( + SerializerUtils.genData(keyClass, index), SerializerUtils.genData(valueClass, index)); + } + + // 3 sort + recordBuffer.sort(SerializerUtils.getComparator(keyClass)); + + // 4 serialize records + RssConf rssConf = new RssConf(); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + RecordsWriter writer = new RecordsWriter(rssConf, outputStream, keyClass, valueClass, false); + recordBuffer.serialize(writer); + writer.close(); + + // 5 check the serialized data + RecordsReader reader = + new RecordsReader<>( + rssConf, + PartialInputStream.newInputStream(ByteBuffer.wrap(outputStream.toByteArray())), + keyClass, + valueClass, + false); + int index = 0; + while (reader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); + index++; + } + assertEquals(RECORDS, index); + reader.close(); + } + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable", + }) + public void testSortCombineAndSerializeRecords(String classes) throws Exception { + // 1 Parse arguments + String[] classArray = classes.split(","); + Class keyClass = SerializerUtils.getClassByName(classArray[0]); + Class valueClass = SerializerUtils.getClassByName(classArray[1]); + + // 2 add record + RecordBuffer recordBuffer = new RecordBuffer(0); + List indexes = new ArrayList<>(); + for (int i = 0; i < RECORDS; i++) { + indexes.add(i); + } + Collections.shuffle(indexes); + for (Integer index : indexes) { + int times = index % 3 + 1; + for (int j = 0; j < times; j++) { + recordBuffer.addRecord( + SerializerUtils.genData(keyClass, index), + SerializerUtils.genData(valueClass, index + j)); + } + } + + // 3 sort and combine + recordBuffer.sort(SerializerUtils.getComparator(keyClass)); + RecordBlob recordBlob = new RecordBlob(0); + recordBlob.addRecords(recordBuffer); + recordBlob.combine(new SumByKeyCombiner(), false); + + // 4 serialize records + 
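+    // After the combine above, key index carries a single summed value: index when
+    // index % 3 == 0, 2 * index + 1 when index % 3 == 1, and 3 * index + 3 when
+    // index % 3 == 2 (key index was added index % 3 + 1 times with values index, index + 1, ...).
+    // Step 5 below recomputes this as aimValue when validating the serialized data.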
RssConf rssConf = new RssConf(); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + RecordsWriter writer = new RecordsWriter(rssConf, outputStream, keyClass, valueClass, false); + recordBlob.serialize(writer); + writer.close(); + + // 5 check the serialized data + RecordsReader reader = + new RecordsReader<>( + rssConf, + PartialInputStream.newInputStream(ByteBuffer.wrap(outputStream.toByteArray())), + keyClass, + valueClass, + false); + int index = 0; + while (reader.next()) { + int aimValue = index; + if (index % 3 == 1) { + aimValue = 2 * aimValue + 1; + } + if (index % 3 == 2) { + aimValue = 3 * aimValue + 3; + } + assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, aimValue), reader.getCurrentValue()); + index++; + } + reader.close(); + assertEquals(RECORDS, index); + } +} diff --git a/client/src/test/java/org/apache/uniffle/client/record/writer/SumByKeyCombiner.java b/client/src/test/java/org/apache/uniffle/client/record/writer/SumByKeyCombiner.java new file mode 100644 index 0000000000..b81009b142 --- /dev/null +++ b/client/src/test/java/org/apache/uniffle/client/record/writer/SumByKeyCombiner.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.uniffle.client.record.writer;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.io.DataInputBuffer;
+import org.apache.hadoop.io.IntWritable;
+
+import org.apache.uniffle.client.record.Record;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.serializer.SerializerInstance;
+import org.apache.uniffle.common.serializer.writable.ComparativeOutputBuffer;
+
+public class SumByKeyCombiner extends Combiner {
+
+  private final boolean raw;
+  private final SerializerInstance instance;
+  private final Class keyClass;
+  private final Class valueClass;
+
+  public SumByKeyCombiner() {
+    this(false, null, null, null);
+  }
+
+  public SumByKeyCombiner(
+      boolean raw, SerializerInstance instance, Class keyClass, Class valueClass) {
+    this.raw = raw;
+    this.instance = instance;
+    this.keyClass = keyClass;
+    this.valueClass = valueClass;
+  }
+
+  @Override
+  public List<Record> combineValues(Iterator<Map.Entry<Object, List<Record>>> recordIterator) {
+    List<Record> collected = new ArrayList<>();
+    while (recordIterator.hasNext()) {
+      Map.Entry<Object, List<Record>> entry =
+          (Map.Entry<Object, List<Record>>) recordIterator.next();
+      Record current = null;
+      List<Record> records = entry.getValue();
+      for (Record record : records) {
+        Record newRecord;
+        if (raw) {
+          // In raw mode the key and value are serialized buffers, so deserialize them first.
+          ComparativeOutputBuffer keyBuffer = (ComparativeOutputBuffer) record.getKey();
+          ComparativeOutputBuffer valueBuffer = (ComparativeOutputBuffer) record.getValue();
+          DataInputBuffer keyInputBuffer = new DataInputBuffer();
+          DataInputBuffer valueInputBuffer = new DataInputBuffer();
+          keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength());
+          valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength());
+          try {
+            Object key = instance.deserialize(keyInputBuffer, keyClass);
+            Object value = instance.deserialize(valueInputBuffer, valueClass);
+            newRecord = Record.create(key, value);
+          } catch (IOException e) {
+            throw new RssException(e);
+          }
+        } else {
+          newRecord = record;
+        }
+        if (current == null) {
+          current = newRecord;
+          collected.add(current);
+        } else {
+          int v1 =
+              newRecord.getValue() instanceof IntWritable
                  ? ((IntWritable) newRecord.getValue()).get()
+                  : (int) newRecord.getValue();
+          int v2 =
+              current.getValue() instanceof IntWritable
+                  ? ((IntWritable) current.getValue()).get()
+                  : (int) current.getValue();
+          if (current.getValue() instanceof IntWritable) {
+            ((IntWritable) current.getValue()).set(v1 + v2);
+          } else {
+            current.setValue(v1 + v2);
+          }
+        }
+      }
+    }
+    if (raw) {
+      // Re-serialize the summed records so the output stays in raw buffer form.
+      List<Record> ret = new ArrayList<>();
+      for (Record record : collected) {
+        ComparativeOutputBuffer keyBuffer = new ComparativeOutputBuffer();
+        ComparativeOutputBuffer valueBuffer = new ComparativeOutputBuffer();
+        try {
+          instance.serialize(record.getKey(), keyBuffer);
+          instance.serialize(record.getValue(), valueBuffer);
+        } catch (IOException e) {
+          throw new RssException(e);
+        }
+        ret.add(Record.create(keyBuffer, valueBuffer));
+      }
+      return ret;
+    } else {
+      return collected;
+    }
+  }
+
+  @Override
+  public List<Record> combineCombiners(Iterator<Map.Entry<Object, List<Record>>> recordIterator) {
+    // Summing is associative, so combining partial sums is the same operation.
+    return combineValues(recordIterator);
+  }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index 1904cf7430..f3e16de320 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -252,4 +252,29 @@ public class RssClientConf {
           .defaultValue(false)
           .withDescription(
               "Whether to support rss client block send failure retry, default value is false.");
+
+  public static final ConfigOption<Integer> RSS_CLIENT_REMOTE_MERGE_FETCH_INIT_SLEEP_MS =
+      ConfigOptions.key("rss.client.remote.merge.fetch.initSleepMs")
+          .intType()
+          .defaultValue(100)
+          .withDescription("the initial sleep time in ms when fetching remote merged records");
+
+  public static final ConfigOption<Integer> RSS_CLIENT_REMOTE_MERGE_FETCH_MAX_SLEEP_MS =
+      ConfigOptions.key("rss.client.remote.merge.fetch.maxSleepMs")
+          .intType()
+          .defaultValue(5000)
+          .withDescription("the maximum sleep time in ms when fetching remote merged records");
+
+  public static final ConfigOption<Integer> RSS_CLIENT_REMOTE_MERGE_READER_MAX_BUFFER =
+      ConfigOptions.key("rss.client.remote.merge.reader.maxBuffer")
+          .intType()
+          .defaultValue(2)
+          .withDescription(
+              "the maximum number of buffers queued per partition when fetching remote merged records");
+
+  public static final ConfigOption<Integer> RSS_CLIENT_REMOTE_MERGE_READER_MAX_RECORDS_PER_BUFFER =
+      ConfigOptions.key("rss.client.remote.merge.reader.maxRecordsPerBuffer")
+          .intType()
+          .defaultValue(500)
+          .withDescription(
+              "the maximum number of records per buffer when fetching remote merged records");
 }
diff --git a/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java b/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java
index 1966a3e511..b4f0117f51 100644
--- a/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java
+++ b/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java
@@ -19,13 +19,13 @@
 
 import java.io.File;
 import java.io.IOException;
+import java.nio.ByteBuffer;
 
 import io.netty.buffer.ByteBuf;
 
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.records.RecordsReader;
 import org.apache.uniffle.common.serializer.PartialInputStream;
-import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
 
 public class StreamedSegment<K, V> extends Segment {
@@ -49,11 +49,10 @@ public StreamedSegment(
     super(blockId);
     this.byteBuf = byteBuf;
     this.byteBuf.retain();
-    byte[] buffer = byteBuf.array();
     this.reader =
         new RecordsReader<>(
             rssConf,
-            PartialInputStreamImpl.newInputStream(buffer, 0, buffer.length),
+            PartialInputStream.newInputStream(byteBuf.nioBuffer()),
keyClass, valueClass, raw); @@ -61,16 +60,17 @@ public StreamedSegment( // The buffer must be sorted by key public StreamedSegment( - RssConf rssConf, byte[] buffer, long blockId, Class keyClass, Class valueClass, boolean raw) + RssConf rssConf, + ByteBuffer byteBuffer, + long blockId, + Class keyClass, + Class valueClass, + boolean raw) throws IOException { super(blockId); this.reader = new RecordsReader<>( - rssConf, - PartialInputStreamImpl.newInputStream(buffer, 0, buffer.length), - keyClass, - valueClass, - raw); + rssConf, PartialInputStream.newInputStream(byteBuffer), keyClass, valueClass, raw); } public StreamedSegment( @@ -87,7 +87,7 @@ public StreamedSegment( this.reader = new RecordsReader( rssConf, - PartialInputStreamImpl.newInputStream(file, start, end), + PartialInputStream.newInputStream(file, start, end), keyClass, valueClass, raw); diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/BufferPartialInputStreamImpl.java b/common/src/main/java/org/apache/uniffle/common/serializer/BufferPartialInputStreamImpl.java new file mode 100644 index 0000000000..81fb826e12 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/serializer/BufferPartialInputStreamImpl.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.common.serializer; + +import java.io.IOException; +import java.nio.ByteBuffer; + +public class BufferPartialInputStreamImpl extends PartialInputStream { + + private ByteBuffer buffer; + private final long start; // the start of source input stream + private final long end; // the end of source input stream + + public BufferPartialInputStreamImpl(ByteBuffer byteBuffer, long start, long end) + throws IOException { + if (start < 0) { + throw new IOException("Negative position for channel!"); + } + this.buffer = byteBuffer; + this.start = start; + this.end = end; + this.buffer.position((int) start); + } + + @Override + public int read() throws IOException { + if (available() <= 0) { + return -1; + } + return this.buffer.get() & 0xff; + } + + @Override + public int available() throws IOException { + return (int) (end - this.buffer.position()); + } + + @Override + public long getStart() { + return start; + } + + @Override + public long getEnd() { + return end; + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStream.java b/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStream.java index 3319802162..f1a18595ee 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStream.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStream.java @@ -17,8 +17,12 @@ package org.apache.uniffle.common.serializer; +import java.io.File; +import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; /* * PartialInputStream is a configurable partial input stream, which @@ -32,4 +36,34 @@ public abstract class PartialInputStream extends InputStream { public abstract long getStart(); public abstract long getEnd(); + + public static PartialInputStream newInputStream(File file, long start, long end) + throws IOException { + FileInputStream input = new FileInputStream(file); + FileChannel fc = input.getChannel(); + if (fc == null) { + throw new NullPointerException("channel is null!"); + } + long size = fc.size(); + return new PartialInputStreamImpl( + fc, + start, + Math.min(end, size), + () -> { + input.close(); + }); + } + + public static PartialInputStream newInputStream(File file) throws IOException { + return PartialInputStream.newInputStream(file, 0, file.length()); + } + + public static PartialInputStream newInputStream(ByteBuffer byteBuffer, long start, long end) + throws IOException { + return new BufferPartialInputStreamImpl(byteBuffer, start, Math.min(byteBuffer.limit(), end)); + } + + public static PartialInputStream newInputStream(ByteBuffer byteBuffer) throws IOException { + return new BufferPartialInputStreamImpl(byteBuffer, 0, byteBuffer.limit()); + } } diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStreamImpl.java b/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStreamImpl.java index 6bf8910fed..b16a7c8d8d 100644 --- a/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStreamImpl.java +++ b/common/src/main/java/org/apache/uniffle/common/serializer/PartialInputStreamImpl.java @@ -18,11 +18,8 @@ package org.apache.uniffle.common.serializer; import java.io.Closeable; -import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; import java.nio.channels.SeekableByteChannel; /* @@ -114,33 +111,4 @@ 
public void close() throws IOException { closeable.close(); } } - - private static PartialInputStreamImpl newInputStream( - SeekableByteChannel ch, long start, long end, Closeable closeable) throws IOException { - if (ch == null) { - throw new NullPointerException("channel is null!"); - } - return new PartialInputStreamImpl(ch, start, end, closeable); - } - - public static PartialInputStreamImpl newInputStream(File file, long start, long end) - throws IOException { - FileInputStream input = new FileInputStream(file); - FileChannel fc = input.getChannel(); - long size = fc.size(); - return newInputStream( - fc, - start, - Math.min(end, size), - () -> { - input.close(); - }); - } - - public static PartialInputStreamImpl newInputStream(byte[] bytes, long start, long end) - throws IOException { - SeekableInMemoryByteChannel ch = new SeekableInMemoryByteChannel(bytes); - int size = bytes.length; - return newInputStream(ch, start, Math.min(end, size), () -> ch.close()); - } } diff --git a/common/src/main/java/org/apache/uniffle/common/serializer/SeekableInMemoryByteChannel.java b/common/src/main/java/org/apache/uniffle/common/serializer/SeekableInMemoryByteChannel.java deleted file mode 100644 index cddd2d61ab..0000000000 --- a/common/src/main/java/org/apache/uniffle/common/serializer/SeekableInMemoryByteChannel.java +++ /dev/null @@ -1,165 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.uniffle.common.serializer; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; -import java.nio.channels.SeekableByteChannel; -import java.util.Arrays; -import java.util.concurrent.atomic.AtomicBoolean; - -public class SeekableInMemoryByteChannel implements SeekableByteChannel { - - private static final int NAIVE_RESIZE_LIMIT = Integer.MAX_VALUE >> 1; - - private byte[] data; - private final AtomicBoolean closed = new AtomicBoolean(); - private int position; - private int size; - - /** - * Constructor taking a byte array. - * - *
This constructor is intended to be used with pre-allocated buffer or when reading from a - * given byte array. - * - * @param data input data or pre-allocated array. - */ - public SeekableInMemoryByteChannel(byte[] data) { - this.data = data; - size = data.length; - } - - @Override - public long position() { - return position; - } - - @Override - public SeekableByteChannel position(long newPosition) throws IOException { - ensureOpen(); - if (newPosition < 0L || newPosition > Integer.MAX_VALUE) { - throw new IllegalArgumentException("Position has to be in range 0.. " + Integer.MAX_VALUE); - } - position = (int) newPosition; - return this; - } - - @Override - public long size() { - return size; - } - - @Override - public SeekableByteChannel truncate(long newSize) { - if (size > newSize) { - size = (int) newSize; - } - repositionIfNecessary(); - return this; - } - - @Override - public int read(ByteBuffer buf) throws IOException { - ensureOpen(); - repositionIfNecessary(); - int wanted = buf.remaining(); - int possible = size - position; - if (possible <= 0) { - return -1; - } - if (wanted > possible) { - wanted = possible; - } - buf.put(data, position, wanted); - position += wanted; - return wanted; - } - - @Override - public void close() { - closed.set(true); - } - - @Override - public boolean isOpen() { - return !closed.get(); - } - - @Override - public int write(ByteBuffer b) throws IOException { - ensureOpen(); - int wanted = b.remaining(); - int possibleWithoutResize = size - position; - if (wanted > possibleWithoutResize) { - int newSize = position + wanted; - if (newSize < 0) { // overflow - resize(Integer.MAX_VALUE); - wanted = Integer.MAX_VALUE - position; - } else { - resize(newSize); - } - } - b.get(data, position, wanted); - position += wanted; - if (size < position) { - size = position; - } - return wanted; - } - - /** - * Obtains the array backing this channel. - * - *
NOTE: The returned buffer is not aligned with containing data, use {@link #size()} to obtain - * the size of data stored in the buffer. - * - * @return internal byte array. - */ - public byte[] array() { - return data; - } - - private void resize(int newLength) { - int len = data.length; - if (len <= 0) { - len = 1; - } - if (newLength < NAIVE_RESIZE_LIMIT) { - while (len < newLength) { - len <<= 1; - } - } else { // avoid overflow - len = newLength; - } - data = Arrays.copyOf(data, len); - } - - private void ensureOpen() throws ClosedChannelException { - if (!isOpen()) { - throw new ClosedChannelException(); - } - } - - private void repositionIfNecessary() { - if (position > size) { - position = size; - } - } -} diff --git a/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java b/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java index 1757ad0047..970b05d66c 100644 --- a/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java +++ b/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java @@ -30,7 +30,7 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.records.RecordsReader; -import org.apache.uniffle.common.serializer.PartialInputStreamImpl; +import org.apache.uniffle.common.serializer.PartialInputStream; import org.apache.uniffle.common.serializer.SerializerUtils; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -96,11 +96,7 @@ public void testMergeSegmentToFile(String classes, @TempDir File tmpDir) throws // 3 Check the merged file RecordsReader reader = new RecordsReader( - rssConf, - PartialInputStreamImpl.newInputStream(mergedFile, 0, mergedFile.length()), - keyClass, - valueClass, - false); + rssConf, PartialInputStream.newInputStream(mergedFile), keyClass, valueClass, false); int index = 0; while (reader.next()) { assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); diff --git a/common/src/test/java/org/apache/uniffle/common/records/RecordsReaderWriterTest.java b/common/src/test/java/org/apache/uniffle/common/records/RecordsReaderWriterTest.java index 0fc6e16bea..a5f26b9a43 100644 --- a/common/src/test/java/org/apache/uniffle/common/records/RecordsReaderWriterTest.java +++ b/common/src/test/java/org/apache/uniffle/common/records/RecordsReaderWriterTest.java @@ -21,6 +21,7 @@ import java.io.File; import java.io.FileOutputStream; import java.io.OutputStream; +import java.nio.ByteBuffer; import java.util.Random; import org.apache.hadoop.io.DataInputBuffer; @@ -30,7 +31,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.apache.uniffle.common.config.RssConf; -import org.apache.uniffle.common.serializer.PartialInputStreamImpl; +import org.apache.uniffle.common.serializer.PartialInputStream; import org.apache.uniffle.common.serializer.Serializer; import org.apache.uniffle.common.serializer.SerializerFactory; import org.apache.uniffle.common.serializer.SerializerInstance; @@ -74,11 +75,11 @@ public void testWriteAndReadRecordFile1(String classes, @TempDir File tmpDir) th // 3 Read // 3.1 read from start - PartialInputStreamImpl inputStream = + PartialInputStream inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, 0, tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), 0, Long.MAX_VALUE); + ? 
PartialInputStream.newInputStream(tmpFile) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray())); RecordsReader reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); int index = 0; while (reader.next()) { @@ -92,11 +93,11 @@ public void testWriteAndReadRecordFile1(String classes, @TempDir File tmpDir) th // 3.2 read from end inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), + ? PartialInputStream.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), offsets[RECORDS - 1], - Long.MAX_VALUE); + ((ByteArrayOutputStream) outputStream).size()); reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); assertFalse(reader.next()); reader.close(); @@ -116,9 +117,11 @@ public void testWriteAndReadRecordFile1(String classes, @TempDir File tmpDir) th long offset = indexAndOffset[1]; inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, offset, tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), offset, Long.MAX_VALUE); + ? PartialInputStream.newInputStream(tmpFile, offset, tmpFile.length()) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), + offset, + ((ByteArrayOutputStream) outputStream).size()); reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); while (reader.next()) { assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); @@ -163,11 +166,11 @@ public void testWriteAndReadRecordFile2(String classes, @TempDir File tmpDir) th // 3 Read // 3.1 read from start - PartialInputStreamImpl inputStream = + PartialInputStream inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, 0, tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), 0, Long.MAX_VALUE); + ? PartialInputStream.newInputStream(tmpFile) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray())); RecordsReader reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); int index = 0; while (reader.next()) { @@ -190,11 +193,11 @@ public void testWriteAndReadRecordFile2(String classes, @TempDir File tmpDir) th // 3.2 read from end inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), + ? PartialInputStream.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), offsets[RECORDS - 1], - Long.MAX_VALUE); + ((ByteArrayOutputStream) outputStream).size()); reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); assertFalse(reader.next()); reader.close(); @@ -214,9 +217,11 @@ public void testWriteAndReadRecordFile2(String classes, @TempDir File tmpDir) th long offset = indexAndOffset[1]; inputStream = isFileMode - ? 
PartialInputStreamImpl.newInputStream(tmpFile, offset, tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), offset, Long.MAX_VALUE); + ? PartialInputStream.newInputStream(tmpFile, offset, tmpFile.length()) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), + offset, + ((ByteArrayOutputStream) outputStream).size()); reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); while (reader.next()) { DataOutputBuffer keyBuffer = (DataOutputBuffer) reader.getCurrentKey(); @@ -275,11 +280,11 @@ public void testWriteAndReadRecordFile3(String classes, @TempDir File tmpDir) th // 3 Read // 3.1 read from start - PartialInputStreamImpl inputStream = + PartialInputStream inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, 0, tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), 0, Long.MAX_VALUE); + ? PartialInputStream.newInputStream(tmpFile) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray())); RecordsReader reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); int index = 0; while (reader.next()) { @@ -293,11 +298,11 @@ public void testWriteAndReadRecordFile3(String classes, @TempDir File tmpDir) th // 3.2 read from end inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), + ? PartialInputStream.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), offsets[RECORDS - 1], - Long.MAX_VALUE); + ((ByteArrayOutputStream) outputStream).size()); reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); assertFalse(reader.next()); reader.close(); @@ -317,9 +322,11 @@ public void testWriteAndReadRecordFile3(String classes, @TempDir File tmpDir) th long offset = indexAndOffset[1]; inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, offset, tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), offset, Long.MAX_VALUE); + ? PartialInputStream.newInputStream(tmpFile, offset, tmpFile.length()) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), + offset, + ((ByteArrayOutputStream) outputStream).size()); reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, false); while (reader.next()) { assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); @@ -368,11 +375,11 @@ public void testWriteAndReadRecordFile4(String classes, @TempDir File tmpDir) th // 3 Read // 3.1 read from start - PartialInputStreamImpl inputStream = + PartialInputStream inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, 0, tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), 0, Long.MAX_VALUE); + ? 
PartialInputStream.newInputStream(tmpFile) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray())); RecordsReader reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); int index = 0; while (reader.next()) { @@ -395,11 +402,11 @@ public void testWriteAndReadRecordFile4(String classes, @TempDir File tmpDir) th // 3.2 read from end inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), + ? PartialInputStream.newInputStream(tmpFile, offsets[RECORDS - 1], tmpFile.length()) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), offsets[RECORDS - 1], - Long.MAX_VALUE); + ((ByteArrayOutputStream) outputStream).size()); reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); assertFalse(reader.next()); reader.close(); @@ -419,9 +426,11 @@ public void testWriteAndReadRecordFile4(String classes, @TempDir File tmpDir) th long offset = indexAndOffset[1]; inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream(tmpFile, offset, tmpFile.length()) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), offset, Long.MAX_VALUE); + ? PartialInputStream.newInputStream(tmpFile, offset, tmpFile.length()) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), + offset, + ((ByteArrayOutputStream) outputStream).size()); reader = new RecordsReader(rssConf, inputStream, keyClass, valueClass, true); while (reader.next()) { DataOutputBuffer keyBuffer = (DataOutputBuffer) reader.getCurrentKey(); diff --git a/common/src/test/java/org/apache/uniffle/common/serializer/PartialInputStreamTest.java b/common/src/test/java/org/apache/uniffle/common/serializer/PartialInputStreamTest.java index 16387828e4..6be871a93b 100644 --- a/common/src/test/java/org/apache/uniffle/common/serializer/PartialInputStreamTest.java +++ b/common/src/test/java/org/apache/uniffle/common/serializer/PartialInputStreamTest.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.FileOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Random; import org.junit.jupiter.api.BeforeAll; @@ -32,42 +33,44 @@ public class PartialInputStreamTest { private static final int BYTES_LEN = 10240; - private static byte[] testBuffer = new byte[BYTES_LEN]; + private static ByteBuffer testBuffer; private static final int LOOP = 10; @TempDir private static File tempDir; private static File tempFile; @BeforeAll public static void initData() throws IOException { + byte[] bytes = new byte[BYTES_LEN]; for (int i = 0; i < BYTES_LEN; i++) { - testBuffer[i] = (byte) (i & 0x7F); + bytes[i] = (byte) (i & 0x7F); } + testBuffer = ByteBuffer.wrap(bytes); tempFile = new File(tempDir, "data"); FileOutputStream output = new FileOutputStream(tempFile); - output.write(testBuffer); + output.write(bytes); output.close(); } @Test public void testReadMemroyInputStream() throws IOException { // 1 test whole file - testRandomReadMemory(testBuffer, 0, testBuffer.length); + testRandomReadMemory(testBuffer, 0, BYTES_LEN); // 2 test from start to random end Random random = new Random(); for (int i = 0; i < LOOP; i++) { - testRandomReadMemory(testBuffer, 0, random.nextInt(testBuffer.length - 1)); + testRandomReadMemory(testBuffer, 0, 
random.nextInt(BYTES_LEN - 1)); } // 3 test from random start to end for (int i = 0; i < LOOP; i++) { - testRandomReadMemory(testBuffer, random.nextInt(testBuffer.length - 1), testBuffer.length); + testRandomReadMemory(testBuffer, random.nextInt(BYTES_LEN - 1), BYTES_LEN); } // 4 test from random start to random end for (int i = 0; i < LOOP; i++) { - int r1 = random.nextInt(testBuffer.length - 2) + 1; - int r2 = random.nextInt(testBuffer.length - 2) + 1; + int r1 = random.nextInt(BYTES_LEN - 2) + 1; + int r2 = random.nextInt(BYTES_LEN - 2) + 1; testRandomReadMemory(testBuffer, Math.min(r1, r2), Math.max(r1, r2)); } @@ -75,25 +78,19 @@ public void testReadMemroyInputStream() throws IOException { testRandomReadMemory(testBuffer, 0, 0); // 6 Test when bytes is from end to end - testRandomReadMemory(testBuffer, testBuffer.length, testBuffer.length); + testRandomReadMemory(testBuffer, BYTES_LEN, BYTES_LEN); // 7 Test when bytes is from random to this random for (int i = 0; i < LOOP; i++) { - int r = random.nextInt(testBuffer.length - 2) + 1; + int r = random.nextInt(BYTES_LEN - 2) + 1; testRandomReadMemory(testBuffer, r, r); } } @Test public void testReadNullBytes() throws IOException { - byte[] bytes = new byte[BYTES_LEN]; - for (int i = 0; i < BYTES_LEN; i++) { - bytes[i] = (byte) (i & 0x7F); - } - // Test when bytes is byte[0] - PartialInputStreamImpl input = - PartialInputStreamImpl.newInputStream(new byte[0], 0, bytes.length); + PartialInputStream input = PartialInputStream.newInputStream(ByteBuffer.wrap(new byte[0])); assertEquals(0, input.available()); assertEquals(-1, input.read()); input.close(); @@ -102,23 +99,23 @@ public void testReadNullBytes() throws IOException { @Test public void testReadFileInputStream() throws IOException { // 1 test whole file - testRandomReadFile(tempFile, 0, testBuffer.length); + testRandomReadFile(tempFile, 0, BYTES_LEN); // 2 test from start to random end Random random = new Random(); for (int i = 0; i < LOOP; i++) { - testRandomReadFile(tempFile, 0, random.nextInt(testBuffer.length - 1)); + testRandomReadFile(tempFile, 0, random.nextInt(BYTES_LEN - 1)); } // 3 test from random start to end for (int i = 0; i < LOOP; i++) { - testRandomReadFile(tempFile, random.nextInt(testBuffer.length - 1), testBuffer.length); + testRandomReadFile(tempFile, random.nextInt(BYTES_LEN - 1), BYTES_LEN); } // 4 test from random start to random end for (int i = 0; i < LOOP; i++) { - int r1 = random.nextInt(testBuffer.length - 2) + 1; - int r2 = random.nextInt(testBuffer.length - 2) + 1; + int r1 = random.nextInt(BYTES_LEN - 2) + 1; + int r2 = random.nextInt(BYTES_LEN - 2) + 1; testRandomReadFile(tempFile, Math.min(r1, r2), Math.max(r1, r2)); } @@ -126,36 +123,37 @@ public void testReadFileInputStream() throws IOException { testRandomReadFile(tempFile, 0, 0); // 6 Test when bytes is from end to end - testRandomReadFile(tempFile, testBuffer.length, testBuffer.length); + testRandomReadFile(tempFile, BYTES_LEN, BYTES_LEN); // 7 Test when bytes is from random to this random for (int i = 0; i < LOOP; i++) { - int r = random.nextInt(testBuffer.length - 2) + 1; + int r = random.nextInt(BYTES_LEN - 2) + 1; testRandomReadFile(tempFile, r, r); } } - private void testRandomReadMemory(byte[] bytes, long start, long end) throws IOException { - PartialInputStreamImpl input = PartialInputStreamImpl.newInputStream(bytes, start, end); + private void testRandomReadMemory(ByteBuffer byteBuffer, long start, long end) + throws IOException { + PartialInputStream input = 
PartialInputStream.newInputStream(byteBuffer, start, end); testRandomReadOneBytePerTime(input, start, end); input.close(); - input = PartialInputStreamImpl.newInputStream(bytes, start, end); + input = PartialInputStream.newInputStream(byteBuffer, start, end); testRandomReadMultiBytesPerTime(input, start, end); input.close(); } private void testRandomReadFile(File file, long start, long end) throws IOException { - PartialInputStreamImpl input = PartialInputStreamImpl.newInputStream(file, start, end); + PartialInputStream input = PartialInputStream.newInputStream(file, start, end); testRandomReadOneBytePerTime(input, start, end); input.close(); - input = PartialInputStreamImpl.newInputStream(file, start, end); + input = PartialInputStream.newInputStream(file, start, end); testRandomReadMultiBytesPerTime(input, start, end); input.close(); } - private void testRandomReadOneBytePerTime(PartialInputStreamImpl input, long start, long end) + private void testRandomReadOneBytePerTime(PartialInputStream input, long start, long end) throws IOException { // test read one byte per time long index = start; @@ -173,7 +171,7 @@ private void testRandomReadOneBytePerTime(PartialInputStreamImpl input, long sta } } - void testRandomReadMultiBytesPerTime(PartialInputStreamImpl input, long start, long end) + void testRandomReadMultiBytesPerTime(PartialInputStream input, long start, long end) throws IOException { // test read multi bytes per times long index = start; diff --git a/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java b/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java index d5675182bf..e6679a138b 100644 --- a/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java +++ b/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java @@ -22,6 +22,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.nio.ByteBuffer; import java.util.Comparator; import com.google.common.base.Objects; @@ -174,7 +175,8 @@ public static Segment genMemorySegment( throws IOException { ByteArrayOutputStream output = new ByteArrayOutputStream(); genSortedRecord(rssConf, keyClass, valueClass, start, interval, length, output, 1); - return new StreamedSegment(rssConf, output.toByteArray(), blockId, keyClass, valueClass, raw); + return new StreamedSegment( + rssConf, ByteBuffer.wrap(output.toByteArray()), blockId, keyClass, valueClass, raw); } public static Segment genFileSegment( diff --git a/common/src/test/java/org/apache/uniffle/common/serializer/WritableSerializerTest.java b/common/src/test/java/org/apache/uniffle/common/serializer/WritableSerializerTest.java index a3f719a0ed..99853f3c6d 100644 --- a/common/src/test/java/org/apache/uniffle/common/serializer/WritableSerializerTest.java +++ b/common/src/test/java/org/apache/uniffle/common/serializer/WritableSerializerTest.java @@ -21,6 +21,7 @@ import java.io.File; import java.io.FileOutputStream; import java.io.OutputStream; +import java.nio.ByteBuffer; import java.util.Random; import org.apache.hadoop.io.DataInputBuffer; @@ -75,12 +76,13 @@ public void testSerDeKeyValues1(String classes, @TempDir File tmpDir) throws Exc // 3 Random read for (int i = 0; i < LOOP; i++) { long off = offsets[i]; - PartialInputStreamImpl inputStream = + PartialInputStream inputStream = isFileMode - ? 
PartialInputStreamImpl.newInputStream( - new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), off, Long.MAX_VALUE); + ? PartialInputStream.newInputStream(new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), + off, + ((ByteArrayOutputStream) outputStream).size()); DeserializationStream deserializationStream = instance.deserializeStream(inputStream, keyClass, valueClass, false); for (int j = i + 1; j < LOOP; j++) { @@ -124,12 +126,13 @@ public void testSerDeKeyValues2(String classes, @TempDir File tmpDir) throws Exc // 3 Random read for (int i = 0; i < LOOP; i++) { long off = offsets[i]; - PartialInputStreamImpl inputStream = + PartialInputStream inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream( - new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), off, Long.MAX_VALUE); + ? PartialInputStream.newInputStream(new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), + off, + ((ByteArrayOutputStream) outputStream).size()); DeserializationStream deserializationStream = instance.deserializeStream(inputStream, keyClass, valueClass, true); @@ -184,12 +187,13 @@ public void testSerDeKeyValues3(String classes, @TempDir File tmpDir) throws Exc // 3 Random read for (int i = 0; i < LOOP; i++) { long off = offsets[i]; - PartialInputStreamImpl inputStream = + PartialInputStream inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream( - new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), off, Long.MAX_VALUE); + ? PartialInputStream.newInputStream(new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), + off, + ((ByteArrayOutputStream) outputStream).size()); DeserializationStream deserializationStream = instance.deserializeStream(inputStream, keyClass, valueClass, false); for (int j = i + 1; j < LOOP; j++) { @@ -237,12 +241,13 @@ public void testSerDeKeyValues4(String classes, @TempDir File tmpDir) throws Exc // 3 Random read for (int i = 0; i < LOOP; i++) { long off = offsets[i]; - PartialInputStreamImpl inputStream = + PartialInputStream inputStream = isFileMode - ? PartialInputStreamImpl.newInputStream( - new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) - : PartialInputStreamImpl.newInputStream( - ((ByteArrayOutputStream) outputStream).toByteArray(), off, Long.MAX_VALUE); + ? 
PartialInputStream.newInputStream(new File(tmpDir, "tmp.data"), off, Long.MAX_VALUE) + : PartialInputStream.newInputStream( + ByteBuffer.wrap(((ByteArrayOutputStream) outputStream).toByteArray()), + off, + ((ByteArrayOutputStream) outputStream).size()); DeserializationStream deserializationStream = instance.deserializeStream(inputStream, keyClass, valueClass, true); diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java new file mode 100644 index 0000000000..dfc56d6c61 --- /dev/null +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java @@ -0,0 +1,917 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.test; + +import java.io.File; +import java.io.IOException; +import java.net.ServerSocket; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.apache.hadoop.io.IntWritable; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.roaringbitmap.longlong.Roaring64NavigableMap; + +import org.apache.uniffle.client.factory.ShuffleClientFactory; +import org.apache.uniffle.client.impl.ShuffleWriteClientImpl; +import org.apache.uniffle.client.record.reader.KeyValueReader; +import org.apache.uniffle.client.record.reader.RMRecordsReader; +import org.apache.uniffle.client.record.writer.Combiner; +import org.apache.uniffle.client.record.writer.SumByKeyCombiner; +import org.apache.uniffle.common.ClientType; +import org.apache.uniffle.common.PartitionRange; +import org.apache.uniffle.common.RemoteStorageInfo; +import org.apache.uniffle.common.ShuffleBlockInfo; +import org.apache.uniffle.common.ShuffleDataDistributionType; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.rpc.ServerType; +import org.apache.uniffle.common.serializer.Serializer; +import org.apache.uniffle.common.serializer.SerializerFactory; +import org.apache.uniffle.common.serializer.SerializerInstance; +import 
org.apache.uniffle.common.serializer.SerializerUtils; +import org.apache.uniffle.common.util.BlockIdLayout; +import org.apache.uniffle.common.util.ChecksumUtils; +import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.server.ShuffleServerConf; +import org.apache.uniffle.storage.util.StorageType; + +import static org.apache.uniffle.coordinator.CoordinatorConf.COORDINATOR_DYNAMIC_CLIENT_CONF_ENABLED; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RemoteMergeShuffleWithRssClientTest extends ShuffleReadWriteBase { + + private static final int SHUFFLE_ID = 0; + private static final int PARTITION_ID = 0; + + private static ShuffleServerInfo shuffleServerInfo; + private ShuffleWriteClientImpl shuffleWriteClientImpl; + + @BeforeAll + public static void setupServers(@TempDir File tmpDir) throws Exception { + CoordinatorConf coordinatorConf = getCoordinatorConf(); + coordinatorConf.setBoolean(COORDINATOR_DYNAMIC_CLIENT_CONF_ENABLED, false); + createCoordinatorServer(coordinatorConf); + ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC); + shuffleServerConf.set(ShuffleServerConf.SERVER_MERGE_ENABLE, true); + shuffleServerConf.set(ShuffleServerConf.SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE, "1k"); + shuffleServerConf.setLong("rss.server.app.expired.withoutHeartbeat", 10000000); + File dataDir1 = new File(tmpDir, "data1"); + File dataDir2 = new File(tmpDir, "data2"); + String basePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath(); + shuffleServerConf.setString("rss.storage.type", StorageType.LOCALFILE.name()); + shuffleServerConf.setString("rss.storage.basePath", basePath); + List ports = findAvailablePorts(2); + shuffleServerConf.setInteger("rss.rpc.server.port", ports.get(0)); + shuffleServerConf.setInteger("rss.jetty.http.port", ports.get(1)); + createShuffleServer(shuffleServerConf); + startServers(); + shuffleServerInfo = + new ShuffleServerInfo("127.0.0.1-20001", grpcShuffleServers.get(0).getIp(), ports.get(0)); + } + + private static List findAvailablePorts(int num) throws IOException { + List sockets = new ArrayList<>(); + List ports = new ArrayList<>(); + + for (int i = 0; i < num; i++) { + ServerSocket socket = new ServerSocket(0); + ports.add(socket.getLocalPort()); + sockets.add(socket); + } + + for (ServerSocket socket : sockets) { + socket.close(); + } + + return ports; + } + + @BeforeEach + public void createClient() { + shuffleWriteClientImpl = + new ShuffleWriteClientImpl( + ShuffleClientFactory.newWriteBuilder() + .clientType(ClientType.GRPC.name()) + .retryMax(3) + .retryIntervalMax(1000) + .heartBeatThreadNum(1) + .replica(1) + .replicaWrite(1) + .replicaRead(1) + .replicaSkipEnabled(true) + .dataTransferPoolSize(1) + .dataCommitPoolSize(1) + .unregisterThreadPoolSize(10) + .unregisterRequestTimeSec(10)); + } + + @AfterEach + public void closeClient() { + shuffleWriteClientImpl.close(); + } + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + @Timeout(10) + public void remoteMergeWriteReadTest(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final 
boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final Comparator comparator = SerializerUtils.getComparator(keyClass); + final RssConf rssConf = new RssConf(); + + // 2 register shuffle + String testAppId = "remoteMergeWriteReadTest" + classes; + shuffleWriteClientImpl.registerShuffle( + shuffleServerInfo, + testAppId, + SHUFFLE_ID, + Lists.newArrayList(new PartitionRange(0, 0)), + new RemoteStorageInfo(""), + ShuffleDataDistributionType.NORMAL, + 0, + -1, + keyClass.getName(), + valueClass.getName(), + comparator.getClass().getName(), + -1, + null); + + // 3 report shuffle result + // task 0 attempt 0 generate three blocks + BlockIdLayout layout = BlockIdLayout.from(rssConf); + List blocks1 = new ArrayList<>(); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 0, + 5, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 2, + 5, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 4, + 5, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); + // task 1 attempt 0 generate two blocks + List blocks2 = new ArrayList<>(); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 1, + 5, + 1009, + 1)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 3, + 5, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); + Map> partitionToServers = + ImmutableMap.of(PARTITION_ID, Lists.newArrayList(shuffleServerInfo)); + + // 4 report shuffle result + Map> ptb = ImmutableMap.of(PARTITION_ID, new HashSet()); + ptb.get(PARTITION_ID) + .addAll(blocks1.stream().map(s -> s.getBlockId()).collect(Collectors.toList())); + ptb.get(PARTITION_ID) + .addAll(blocks2.stream().map(s -> s.getBlockId()).collect(Collectors.toList())); + Map>> serverToPartitionToBlockIds = new HashMap(); + serverToPartitionToBlockIds.put(shuffleServerInfo, ptb); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 0, 1); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 1, 1); + + // 5 report unique blocks + Roaring64NavigableMap uniqueBlockIds = Roaring64NavigableMap.bitmapOf(); + ptb.get(PARTITION_ID).stream().forEach(block -> uniqueBlockIds.add(block)); + shuffleWriteClientImpl.startSortMerge( + Sets.newHashSet(shuffleServerInfo), testAppId, SHUFFLE_ID, PARTITION_ID, uniqueBlockIds); + + // 6 read result + Map> serverInfoMap = + ImmutableMap.of(PARTITION_ID, Lists.newArrayList(shuffleServerInfo)); + RMRecordsReader reader = + new RMRecordsReader( + testAppId, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID), + serverInfoMap, + rssConf, + keyClass, + valueClass, + comparator, + raw, + null, + false, + null); + reader.start(); + int index = 0; + KeyValueReader keyValueReader = reader.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, 
index), keyValueReader.getCurrentValue()); + index++; + } + assertEquals(5 * 1009, index); + shuffleWriteClientImpl.unregisterShuffle(testAppId); + } + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + @Timeout(10) + public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final Comparator comparator = SerializerUtils.getComparator(keyClass); + final RssConf rssConf = new RssConf(); + SerializerFactory factory = new SerializerFactory(rssConf); + Serializer serializer = factory.getSerializer(keyClass); + SerializerInstance serializerInstance = serializer.newInstance(); + final Combiner combiner = new SumByKeyCombiner(raw, serializerInstance, keyClass, valueClass); + + // 2 register shuffle + String testAppId = "remoteMergeWriteReadTestWithCombine" + classes; + shuffleWriteClientImpl.registerShuffle( + shuffleServerInfo, + testAppId, + SHUFFLE_ID, + Lists.newArrayList(new PartitionRange(0, 0)), + new RemoteStorageInfo(""), + ShuffleDataDistributionType.NORMAL, + 0, + -1, + keyClass.getName(), + valueClass.getName(), + comparator.getClass().getName(), + -1, + null); + + // 3 report shuffle result + // task 0 attempt 0 generate three blocks + BlockIdLayout layout = BlockIdLayout.from(rssConf); + List blocks1 = new ArrayList<>(); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 0, + 3, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 1, + 3, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 2, + 3, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); + // task 1 attempt 0 generate two blocks + List blocks2 = new ArrayList<>(); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 0, + 3, + 1009, + 1)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 2, + 3, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); + Map> partitionToServers = + ImmutableMap.of(PARTITION_ID, Lists.newArrayList(shuffleServerInfo)); + + // 4 report shuffle result + Map> ptb = ImmutableMap.of(PARTITION_ID, new HashSet()); + ptb.get(PARTITION_ID) + .addAll(blocks1.stream().map(s -> s.getBlockId()).collect(Collectors.toList())); + ptb.get(PARTITION_ID) + .addAll(blocks2.stream().map(s -> s.getBlockId()).collect(Collectors.toList())); + Map>> serverToPartitionToBlockIds = new HashMap(); + serverToPartitionToBlockIds.put(shuffleServerInfo, ptb); + shuffleWriteClientImpl.reportShuffleResult( + 
serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 0, 1); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 1, 1); + + // 5 report unique blocks + Roaring64NavigableMap uniqueBlockIds = Roaring64NavigableMap.bitmapOf(); + ptb.get(PARTITION_ID).stream().forEach(block -> uniqueBlockIds.add(block)); + shuffleWriteClientImpl.startSortMerge( + Sets.newHashSet(shuffleServerInfo), testAppId, SHUFFLE_ID, PARTITION_ID, uniqueBlockIds); + + // 6 read result + Map> serverInfoMap = + ImmutableMap.of(PARTITION_ID, Lists.newArrayList(shuffleServerInfo)); + RMRecordsReader reader = + new RMRecordsReader( + testAppId, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID), + serverInfoMap, + rssConf, + keyClass, + valueClass, + comparator, + raw, + combiner, + false, + null); + reader.start(); + int index = 0; + KeyValueReader keyValueReader = reader.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + Object value = SerializerUtils.genData(valueClass, index); + Object newValue = value; + if (index % 3 != 1) { + if (value instanceof IntWritable) { + newValue = new IntWritable(((IntWritable) value).get() * 2); + } else { + newValue = (int) value * 2; + } + } + assertEquals(newValue, keyValueReader.getCurrentValue()); + index++; + } + assertEquals(3 * 1009, index); + shuffleWriteClientImpl.unregisterShuffle(testAppId); + } + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + @Timeout(10) + public void remoteMergeWriteReadTestMultiPartition(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final boolean raw = classArray.length > 2 ? 
Boolean.parseBoolean(classArray[2]) : false; + final Comparator comparator = SerializerUtils.getComparator(keyClass); + final RssConf rssConf = new RssConf(); + + // 2 register shuffle + String testAppId = "remoteMergeWriteReadTestMultiPartition" + classes; + shuffleWriteClientImpl.registerShuffle( + shuffleServerInfo, + testAppId, + SHUFFLE_ID, + Lists.newArrayList( + new PartitionRange(PARTITION_ID, PARTITION_ID), + new PartitionRange(PARTITION_ID + 1, PARTITION_ID + 1), + new PartitionRange(PARTITION_ID + 2, PARTITION_ID + 2)), + new RemoteStorageInfo(""), + ShuffleDataDistributionType.NORMAL, + 0, + -1, + keyClass.getName(), + valueClass.getName(), + comparator.getClass().getName(), + -1, + null); + + // 3 report shuffle result + // this shuffle has three partitions, hashed by key index mod 3 + // task 0 attempt 0 generate three blocks + BlockIdLayout layout = BlockIdLayout.from(rssConf); + List blocks1 = new ArrayList<>(); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 0, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 0, + 6, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 2, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 2, + 6, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 1, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 4, + 6, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); + // task 1 attempt 0 generate three blocks + List blocks2 = new ArrayList<>(); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 1, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 1, + 6, + 1009, + 1)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 0, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 3, + 6, + 1009, + 1)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 2, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 5, + 6, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); + Map> partitionToServers = + ImmutableMap.of( + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 1, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 2, + Lists.newArrayList(shuffleServerInfo)); + + // 4 report shuffle result + Map> ptb = new HashMap<>(); + for (int i = PARTITION_ID; i < PARTITION_ID + 3; i++) { + final int partitionId = i; + ptb.put(partitionId, new HashSet()); + ptb.get(partitionId) + .addAll( + blocks1.stream() + .filter(s -> s.getPartitionId() == partitionId) + .map(s -> s.getBlockId()) + .collect(Collectors.toList())); + ptb.get(partitionId) + .addAll( + blocks2.stream() + .filter(s -> s.getPartitionId() == partitionId) + .map(s -> s.getBlockId()) + .collect(Collectors.toList())); + } + Map>> serverToPartitionToBlockIds = new HashMap(); + serverToPartitionToBlockIds.put(shuffleServerInfo, ptb); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 0, 1); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 1, 1); + + // 5 report unique blocks + for (int i = PARTITION_ID; i < PARTITION_ID + 3; i++) { + Roaring64NavigableMap uniqueBlockIds = Roaring64NavigableMap.bitmapOf(); + ptb.get(i).stream().forEach(block -> uniqueBlockIds.add(block)); + 
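// each partition reports its own bitmap of unique block ids before the server is asked to sort-merge that partition + 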
shuffleWriteClientImpl.startSortMerge( + Sets.newHashSet(shuffleServerInfo), testAppId, SHUFFLE_ID, i, uniqueBlockIds); + } + + // 6 read result + Map> serverInfoMap = + ImmutableMap.of( + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 1, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 2, + Lists.newArrayList(shuffleServerInfo)); + RMRecordsReader reader = + new RMRecordsReader( + testAppId, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2), + serverInfoMap, + rssConf, + keyClass, + valueClass, + comparator, + raw, + null, + false, + null); + reader.start(); + int index = 0; + KeyValueReader keyValueReader = reader.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, index), keyValueReader.getCurrentValue()); + index++; + } + assertEquals(6 * 1009, index); + shuffleWriteClientImpl.unregisterShuffle(testAppId); + } + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + @Timeout(10) + public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final Comparator comparator = SerializerUtils.getComparator(keyClass); + final RssConf rssConf = new RssConf(); + SerializerFactory factory = new SerializerFactory(rssConf); + Serializer serializer = factory.getSerializer(keyClass); + SerializerInstance serializerInstance = serializer.newInstance(); + final Combiner combiner = new SumByKeyCombiner(raw, serializerInstance, keyClass, valueClass); + + // 2 register shuffle + String testAppId = "remoteMergeWriteReadTestMultiPartitionWithCombine" + classes; + shuffleWriteClientImpl.registerShuffle( + shuffleServerInfo, + testAppId, + SHUFFLE_ID, + Lists.newArrayList( + new PartitionRange(PARTITION_ID, PARTITION_ID), + new PartitionRange(PARTITION_ID + 1, PARTITION_ID + 1), + new PartitionRange(PARTITION_ID + 2, PARTITION_ID + 2)), + new RemoteStorageInfo(""), + ShuffleDataDistributionType.NORMAL, + 0, + -1, + keyClass.getName(), + valueClass.getName(), + comparator.getClass().getName(), + -1, + null); + + // 3 report shuffle result + // this shuffle have three partition, which is hash by key index mode 3 + // task 0 attempt 0 generate three blocks + BlockIdLayout layout = BlockIdLayout.from(rssConf); + List blocks1 = new ArrayList<>(); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 0, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 0, + 6, + 1009, + 2)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 2, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 2, + 6, + 1009, + 2)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 1, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 4, + 6, + 1009, + 2)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); + // task 
1 attempt 0 generate three blocks + List blocks2 = new ArrayList<>(); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 1, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 1, + 6, + 1009, + 2)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 0, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 3, + 6, + 1009, + 2)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 2, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 5, + 6, + 1009, + 2)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); + Map> partitionToServers = + ImmutableMap.of( + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 1, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 2, + Lists.newArrayList(shuffleServerInfo)); + + // 4 report shuffle result + Map> ptb = new HashMap<>(); + for (int i = PARTITION_ID; i < PARTITION_ID + 3; i++) { + final int partitionId = i; + ptb.put(partitionId, new HashSet()); + ptb.get(partitionId) + .addAll( + blocks1.stream() + .filter(s -> s.getPartitionId() == partitionId) + .map(s -> s.getBlockId()) + .collect(Collectors.toList())); + ptb.get(partitionId) + .addAll( + blocks2.stream() + .filter(s -> s.getPartitionId() == partitionId) + .map(s -> s.getBlockId()) + .collect(Collectors.toList())); + } + Map>> serverToPartitionToBlockIds = new HashMap(); + serverToPartitionToBlockIds.put(shuffleServerInfo, ptb); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 0, 1); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 1, 1); + + // 5 report unique blocks + for (int i = PARTITION_ID; i < PARTITION_ID + 3; i++) { + Roaring64NavigableMap uniqueBlockIds = Roaring64NavigableMap.bitmapOf(); + ptb.get(i).stream().forEach(block -> uniqueBlockIds.add(block)); + shuffleWriteClientImpl.startSortMerge( + new HashSet<>(partitionToServers.get(i)), testAppId, SHUFFLE_ID, i, uniqueBlockIds); + } + + // 6 read result + Map> serverInfoMap = + ImmutableMap.of( + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 1, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 2, + Lists.newArrayList(shuffleServerInfo)); + RMRecordsReader reader = + new RMRecordsReader( + testAppId, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2), + serverInfoMap, + rssConf, + keyClass, + valueClass, + comparator, + raw, + combiner, + false, + null); + reader.start(); + int index = 0; + KeyValueReader keyValueReader = reader.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + assertEquals( + SerializerUtils.genData(valueClass, index * 2), keyValueReader.getCurrentValue()); + index++; + } + assertEquals(6 * 1009, index); + shuffleWriteClientImpl.unregisterShuffle(testAppId); + } + + private static final AtomicInteger ATOMIC_INT_SORTED = new AtomicInteger(0); + + public static ShuffleBlockInfo createShuffleBlockForRemoteMerge( + RssConf rssConf, + BlockIdLayout blockIdLayout, + int taskAttemptId, + int partitionId, + List shuffleServerInfoList, + Class keyClass, + Class valueClass, + int start, + int interval, + int samples, + int duplicated) + throws IOException { + long blockId = + blockIdLayout.getBlockId(ATOMIC_INT_SORTED.getAndIncrement(), PARTITION_ID, taskAttemptId); + byte[] buf = + 
SerializerUtils.genSortedRecordBytes( + rssConf, keyClass, valueClass, start, interval, samples, duplicated); + return new ShuffleBlockInfo( + SHUFFLE_ID, + partitionId, + blockId, + buf.length, + ChecksumUtils.getCrc32(buf), + buf, + shuffleServerInfoList, + buf.length, + 0, + taskAttemptId); + } +} diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java new file mode 100644 index 0000000000..28d7cc10a0 --- /dev/null +++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java @@ -0,0 +1,924 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.test; + +import java.io.File; +import java.io.IOException; +import java.net.ServerSocket; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.apache.hadoop.io.IntWritable; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.roaringbitmap.longlong.Roaring64NavigableMap; + +import org.apache.uniffle.client.factory.ShuffleClientFactory; +import org.apache.uniffle.client.impl.ShuffleWriteClientImpl; +import org.apache.uniffle.client.record.reader.KeyValueReader; +import org.apache.uniffle.client.record.reader.RMRecordsReader; +import org.apache.uniffle.client.record.writer.Combiner; +import org.apache.uniffle.client.record.writer.SumByKeyCombiner; +import org.apache.uniffle.common.ClientType; +import org.apache.uniffle.common.PartitionRange; +import org.apache.uniffle.common.RemoteStorageInfo; +import org.apache.uniffle.common.ShuffleBlockInfo; +import org.apache.uniffle.common.ShuffleDataDistributionType; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.rpc.ServerType; +import org.apache.uniffle.common.serializer.Serializer; +import org.apache.uniffle.common.serializer.SerializerFactory; +import org.apache.uniffle.common.serializer.SerializerInstance; +import 
org.apache.uniffle.common.serializer.SerializerUtils; +import org.apache.uniffle.common.util.BlockIdLayout; +import org.apache.uniffle.common.util.ChecksumUtils; +import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.server.ShuffleServerConf; +import org.apache.uniffle.storage.util.StorageType; + +import static org.apache.uniffle.coordinator.CoordinatorConf.COORDINATOR_DYNAMIC_CLIENT_CONF_ENABLED; +import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE; +import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MEMORY_SHUFFLE_LOWWATERMARK_PERCENTAGE; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends ShuffleReadWriteBase { + + private static final int SHUFFLE_ID = 0; + private static final int PARTITION_ID = 0; + + private static ShuffleServerInfo shuffleServerInfo; + private ShuffleWriteClientImpl shuffleWriteClientImpl; + + @BeforeAll + public static void setupServers(@TempDir File tmpDir) throws Exception { + CoordinatorConf coordinatorConf = getCoordinatorConf(); + coordinatorConf.setBoolean(COORDINATOR_DYNAMIC_CLIENT_CONF_ENABLED, false); + createCoordinatorServer(coordinatorConf); + ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC); + shuffleServerConf.set(ShuffleServerConf.SERVER_MERGE_ENABLE, true); + shuffleServerConf.set(ShuffleServerConf.SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE, "1k"); + // Each shuffle data will be flushed! + shuffleServerConf.set(SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE, 0.0); + shuffleServerConf.set(SERVER_MEMORY_SHUFFLE_LOWWATERMARK_PERCENTAGE, 0.0); + shuffleServerConf.setLong("rss.server.app.expired.withoutHeartbeat", 10000000); + File dataDir1 = new File(tmpDir, "data1"); + File dataDir2 = new File(tmpDir, "data2"); + String basePath = dataDir1.getAbsolutePath() + "," + dataDir2.getAbsolutePath(); + shuffleServerConf.setString("rss.storage.type", StorageType.LOCALFILE.name()); + shuffleServerConf.setString("rss.storage.basePath", basePath); + List ports = findAvailablePorts(2); + shuffleServerConf.setInteger("rss.rpc.server.port", ports.get(0)); + shuffleServerConf.setInteger("rss.jetty.http.port", ports.get(1)); + createShuffleServer(shuffleServerConf); + startServers(); + shuffleServerInfo = + new ShuffleServerInfo("127.0.0.1-20001", grpcShuffleServers.get(0).getIp(), ports.get(0)); + } + + private static List findAvailablePorts(int num) throws IOException { + List sockets = new ArrayList<>(); + List ports = new ArrayList<>(); + + for (int i = 0; i < num; i++) { + ServerSocket socket = new ServerSocket(0); + ports.add(socket.getLocalPort()); + sockets.add(socket); + } + + for (ServerSocket socket : sockets) { + socket.close(); + } + + return ports; + } + + @BeforeEach + public void createClient() { + shuffleWriteClientImpl = + new ShuffleWriteClientImpl( + ShuffleClientFactory.newWriteBuilder() + .clientType(ClientType.GRPC.name()) + .retryMax(3) + .retryIntervalMax(1000) + .heartBeatThreadNum(1) + .replica(1) + .replicaWrite(1) + .replicaRead(1) + .replicaSkipEnabled(true) + .dataTransferPoolSize(1) + .dataCommitPoolSize(1) + .unregisterThreadPoolSize(10) + .unregisterRequestTimeSec(10)); + } + + @AfterEach + public void closeClient() { + shuffleWriteClientImpl.close(); + } + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + 
"org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + @Timeout(10) + public void remoteMergeWriteReadTest(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final Comparator comparator = SerializerUtils.getComparator(keyClass); + final RssConf rssConf = new RssConf(); + + // 2 register shuffle + String testAppId = "remoteMergeWriteReadTest" + classes; + shuffleWriteClientImpl.registerShuffle( + shuffleServerInfo, + testAppId, + SHUFFLE_ID, + Lists.newArrayList(new PartitionRange(0, 0)), + new RemoteStorageInfo(""), + ShuffleDataDistributionType.NORMAL, + -1, + 0, + keyClass.getName(), + valueClass.getName(), + comparator.getClass().getName(), + -1, + null); + + // 3 report shuffle result + // task 0 attempt 0 generate three blocks + BlockIdLayout layout = BlockIdLayout.from(rssConf); + List blocks1 = new ArrayList<>(); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 0, + 5, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 2, + 5, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 4, + 5, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); + // task 1 attempt 0 generate two blocks + List blocks2 = new ArrayList<>(); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 1, + 5, + 1009, + 1)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 3, + 5, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); + Map> partitionToServers = + ImmutableMap.of(PARTITION_ID, Lists.newArrayList(shuffleServerInfo)); + + // 4 report shuffle result + Map> ptb = ImmutableMap.of(PARTITION_ID, new HashSet()); + ptb.get(PARTITION_ID) + .addAll(blocks1.stream().map(s -> s.getBlockId()).collect(Collectors.toList())); + ptb.get(PARTITION_ID) + .addAll(blocks2.stream().map(s -> s.getBlockId()).collect(Collectors.toList())); + Map>> serverToPartitionToBlockIds = new HashMap(); + serverToPartitionToBlockIds.put(shuffleServerInfo, ptb); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 0, 1); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 1, 1); + + // 5 report unique blocks + Roaring64NavigableMap uniqueBlockIds = Roaring64NavigableMap.bitmapOf(); + ptb.get(PARTITION_ID).stream().forEach(block -> uniqueBlockIds.add(block)); + shuffleWriteClientImpl.startSortMerge( + Sets.newHashSet(shuffleServerInfo), testAppId, SHUFFLE_ID, PARTITION_ID, uniqueBlockIds); + + // 6 read result + Map> serverInfoMap = + ImmutableMap.of(PARTITION_ID, Lists.newArrayList(shuffleServerInfo)); 
+ RMRecordsReader reader = + new RMRecordsReader( + testAppId, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID), + serverInfoMap, + rssConf, + keyClass, + valueClass, + comparator, + raw, + null, + false, + null); + reader.start(); + int index = 0; + KeyValueReader keyValueReader = reader.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, index), keyValueReader.getCurrentValue()); + index++; + } + assertEquals(5 * 1009, index); + shuffleWriteClientImpl.unregisterShuffle(testAppId); + } + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + @Timeout(10) + public void remoteMergeWriteReadTestWithCombine(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final Comparator comparator = SerializerUtils.getComparator(keyClass); + final RssConf rssConf = new RssConf(); + SerializerFactory factory = new SerializerFactory(rssConf); + Serializer serializer = factory.getSerializer(keyClass); + SerializerInstance serializerInstance = serializer.newInstance(); + final Combiner combiner = new SumByKeyCombiner(raw, serializerInstance, keyClass, valueClass); + + // 2 register shuffle + String testAppId = "remoteMergeWriteReadTestWithCombine" + classes; + shuffleWriteClientImpl.registerShuffle( + shuffleServerInfo, + testAppId, + SHUFFLE_ID, + Lists.newArrayList(new PartitionRange(0, 0)), + new RemoteStorageInfo(""), + ShuffleDataDistributionType.NORMAL, + -1, + 0, + keyClass.getName(), + valueClass.getName(), + comparator.getClass().getName(), + -1, + null); + Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + + // 3 report shuffle result + // task 0 attempt 0 generate three blocks + BlockIdLayout layout = BlockIdLayout.from(rssConf); + List blocks1 = new ArrayList<>(); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 0, + 3, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 1, + 3, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 2, + 3, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); + // task 1 attempt 0 generate two blocks + List blocks2 = new ArrayList<>(); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 0, + 3, + 1009, + 1)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 2, + 3, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); + Map> partitionToServers 
= + ImmutableMap.of(PARTITION_ID, Lists.newArrayList(shuffleServerInfo)); + + // 4 report shuffle result + Map> ptb = ImmutableMap.of(PARTITION_ID, new HashSet()); + ptb.get(PARTITION_ID) + .addAll(blocks1.stream().map(s -> s.getBlockId()).collect(Collectors.toList())); + ptb.get(PARTITION_ID) + .addAll(blocks2.stream().map(s -> s.getBlockId()).collect(Collectors.toList())); + Map>> serverToPartitionToBlockIds = new HashMap(); + serverToPartitionToBlockIds.put(shuffleServerInfo, ptb); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 0, 1); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 1, 1); + + // 5 report unique blocks + Roaring64NavigableMap uniqueBlockIds = Roaring64NavigableMap.bitmapOf(); + ptb.get(PARTITION_ID).stream().forEach(block -> uniqueBlockIds.add(block)); + shuffleWriteClientImpl.startSortMerge( + Sets.newHashSet(shuffleServerInfo), testAppId, SHUFFLE_ID, PARTITION_ID, uniqueBlockIds); + + // 6 read result + Map> serverInfoMap = + ImmutableMap.of(PARTITION_ID, Lists.newArrayList(shuffleServerInfo)); + RMRecordsReader reader = + new RMRecordsReader( + testAppId, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID), + serverInfoMap, + rssConf, + keyClass, + valueClass, + comparator, + raw, + combiner, + false, + null); + reader.start(); + int index = 0; + KeyValueReader keyValueReader = reader.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + Object value = SerializerUtils.genData(valueClass, index); + Object newValue = value; + if (index % 3 != 1) { + if (value instanceof IntWritable) { + newValue = new IntWritable(((IntWritable) value).get() * 2); + } else { + newValue = (int) value * 2; + } + } + assertEquals(newValue, keyValueReader.getCurrentValue()); + index++; + } + assertEquals(3 * 1009, index); + shuffleWriteClientImpl.unregisterShuffle(testAppId); + } + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + @Timeout(10) + public void remoteMergeWriteReadTestMultiPartition(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final boolean raw = classArray.length > 2 ? 
Boolean.parseBoolean(classArray[2]) : false; + final Comparator comparator = SerializerUtils.getComparator(keyClass); + final RssConf rssConf = new RssConf(); + + // 2 register shuffle + String testAppId = "remoteMergeWriteReadTestMultiPartition" + classes; + shuffleWriteClientImpl.registerShuffle( + shuffleServerInfo, + testAppId, + SHUFFLE_ID, + Lists.newArrayList( + new PartitionRange(PARTITION_ID, PARTITION_ID), + new PartitionRange(PARTITION_ID + 1, PARTITION_ID + 1), + new PartitionRange(PARTITION_ID + 2, PARTITION_ID + 2)), + new RemoteStorageInfo(""), + ShuffleDataDistributionType.NORMAL, + -1, + 0, + keyClass.getName(), + valueClass.getName(), + comparator.getClass().getName(), + -1, + null); + + // 3 report shuffle result + // this shuffle has three partitions, hashed by key index mod 3 + // task 0 attempt 0 generate three blocks + BlockIdLayout layout = BlockIdLayout.from(rssConf); + List blocks1 = new ArrayList<>(); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 0, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 0, + 6, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 2, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 2, + 6, + 1009, + 1)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 1, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 4, + 6, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false); + // task 1 attempt 0 generate three blocks + List blocks2 = new ArrayList<>(); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 1, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 1, + 6, + 1009, + 1)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 0, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 3, + 6, + 1009, + 1)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 2, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 5, + 6, + 1009, + 1)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); + Map> partitionToServers = + ImmutableMap.of( + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 1, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 2, + Lists.newArrayList(shuffleServerInfo)); + + // 4 report shuffle result + Map> ptb = new HashMap<>(); + for (int i = PARTITION_ID; i < PARTITION_ID + 3; i++) { + final int partitionId = i; + ptb.put(partitionId, new HashSet()); + ptb.get(partitionId) + .addAll( + blocks1.stream() + .filter(s -> s.getPartitionId() == partitionId) + .map(s -> s.getBlockId()) + .collect(Collectors.toList())); + ptb.get(partitionId) + .addAll( + blocks2.stream() + .filter(s -> s.getPartitionId() == partitionId) + .map(s -> s.getBlockId()) + .collect(Collectors.toList())); + } + Map>> serverToPartitionToBlockIds = new HashMap(); + serverToPartitionToBlockIds.put(shuffleServerInfo, ptb); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 0, 1); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 1, 1); + + // 5 report unique blocks + for (int i = PARTITION_ID; i < PARTITION_ID + 3; i++) { + Roaring64NavigableMap uniqueBlockIds = Roaring64NavigableMap.bitmapOf(); + ptb.get(i).stream().forEach(block -> uniqueBlockIds.add(block)); + 
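// as in the non-flushed test, report per-partition unique block ids; here the server merges blocks that were flushed to local files + 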
+ + // 5 report unique blocks + for (int i = PARTITION_ID; i < PARTITION_ID + 3; i++) { + Roaring64NavigableMap uniqueBlockIds = Roaring64NavigableMap.bitmapOf(); + ptb.get(i).forEach(uniqueBlockIds::add); + shuffleWriteClientImpl.startSortMerge( + Sets.newHashSet(shuffleServerInfo), testAppId, SHUFFLE_ID, i, uniqueBlockIds); + } + + // 6 read result + Map<Integer, List<ShuffleServerInfo>> serverInfoMap = + ImmutableMap.of( + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 1, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 2, + Lists.newArrayList(shuffleServerInfo)); + RMRecordsReader reader = + new RMRecordsReader( + testAppId, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2), + serverInfoMap, + rssConf, + keyClass, + valueClass, + comparator, + raw, + null, + false, + null); + reader.start(); + int index = 0; + KeyValueReader keyValueReader = reader.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, index), keyValueReader.getCurrentValue()); + index++; + } + assertEquals(6 * 1009, index); + shuffleWriteClientImpl.unregisterShuffle(testAppId); + } + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,true", + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable,false", + }) + @Timeout(10) + public void remoteMergeWriteReadTestMultiPartitionWithCombine(String classes) throws Exception { + // 1 basic parameter + final String[] classArray = classes.split(","); + final String keyClassName = classArray[0]; + final String valueClassName = classArray[1]; + final Class keyClass = SerializerUtils.getClassByName(keyClassName); + final Class valueClass = SerializerUtils.getClassByName(valueClassName); + final boolean raw = classArray.length > 2 ? Boolean.parseBoolean(classArray[2]) : false; + final Comparator comparator = SerializerUtils.getComparator(keyClass); + final RssConf rssConf = new RssConf(); + SerializerFactory factory = new SerializerFactory(rssConf); + Serializer serializer = factory.getSerializer(keyClass); + SerializerInstance serializerInstance = serializer.newInstance(); + final Combiner combiner = new SumByKeyCombiner(raw, serializerInstance, keyClass, valueClass); + + // 2 register shuffle + String testAppId = "remoteMergeWriteReadTestMultiPartitionWithCombine" + classes; + shuffleWriteClientImpl.registerShuffle( + shuffleServerInfo, + testAppId, + SHUFFLE_ID, + Lists.newArrayList( + new PartitionRange(PARTITION_ID, PARTITION_ID), + new PartitionRange(PARTITION_ID + 1, PARTITION_ID + 1), + new PartitionRange(PARTITION_ID + 2, PARTITION_ID + 2)), + new RemoteStorageInfo(""), + ShuffleDataDistributionType.NORMAL, + -1, + 0, + keyClass.getName(), + valueClass.getName(), + comparator.getClass().getName(), + -1, + null); + + // 3 send shuffle data + // this shuffle has three partitions; each key is hashed to a partition by key index mod 3 + // task 0 attempt 0 generates three blocks + BlockIdLayout layout = BlockIdLayout.from(rssConf); + List<ShuffleBlockInfo> blocks1 = new ArrayList<>(); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 0, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 0, + 6, + 1009, + 2)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 2, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 2, + 6, + 1009, + 2)); + blocks1.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 0, + 1, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 4, + 6, + 1009, + 2)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks1, () -> false);
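An illustrative note, since SumByKeyCombiner's source is not part of this hunk: with duplicated = 2 every key is written twice, and a sum-by-key combine pass collapses each pair into one record with a doubled value, which is why the read loop later asserts genData(valueClass, index * 2). A minimal equivalent of that reduction:

import java.util.Map;
import java.util.TreeMap;

public class SumByKeySketch {
  public static void main(String[] args) {
    Map<Integer, Integer> combined = new TreeMap<>();
    int[][] records = {{7, 7}, {7, 7}, {8, 8}, {8, 8}}; // each key appears twice (duplicated = 2)
    for (int[] rec : records) {
      combined.merge(rec[0], rec[1], Integer::sum); // sum values per key
    }
    System.out.println(combined); // {7=14, 8=16}: every value doubled
  }
}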
+ // task 1 attempt 0 generates three blocks + List<ShuffleBlockInfo> blocks2 = new ArrayList<>(); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 1, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 1, + 6, + 1009, + 2)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 0, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 3, + 6, + 1009, + 2)); + blocks2.add( + createShuffleBlockForRemoteMerge( + rssConf, + layout, + 1, + 2, + Lists.newArrayList(shuffleServerInfo), + keyClass, + valueClass, + 5, + 6, + 1009, + 2)); + shuffleWriteClientImpl.sendShuffleData(testAppId, blocks2, () -> false); + Map<Integer, List<ShuffleServerInfo>> partitionToServers = + ImmutableMap.of( + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 1, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 2, + Lists.newArrayList(shuffleServerInfo)); + + // 4 report shuffle result + Map<Integer, Set<Long>> ptb = new HashMap<>(); + for (int i = PARTITION_ID; i < PARTITION_ID + 3; i++) { + final int partitionId = i; + ptb.put(partitionId, new HashSet<>()); + ptb.get(partitionId) + .addAll( + blocks1.stream() + .filter(s -> s.getPartitionId() == partitionId) + .map(s -> s.getBlockId()) + .collect(Collectors.toList())); + ptb.get(partitionId) + .addAll( + blocks2.stream() + .filter(s -> s.getPartitionId() == partitionId) + .map(s -> s.getBlockId()) + .collect(Collectors.toList())); + } + Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds = new HashMap<>(); + serverToPartitionToBlockIds.put(shuffleServerInfo, ptb); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 0, 1); + shuffleWriteClientImpl.reportShuffleResult( + serverToPartitionToBlockIds, testAppId, SHUFFLE_ID, 1, 1); + + // 5 report unique blocks + for (int i = PARTITION_ID; i < PARTITION_ID + 3; i++) { + Roaring64NavigableMap uniqueBlockIds = Roaring64NavigableMap.bitmapOf(); + ptb.get(i).forEach(uniqueBlockIds::add); + shuffleWriteClientImpl.startSortMerge( + new HashSet<>(partitionToServers.get(i)), testAppId, SHUFFLE_ID, i, uniqueBlockIds); + } + + // 6 read result + Map<Integer, List<ShuffleServerInfo>> serverInfoMap = + ImmutableMap.of( + PARTITION_ID, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 1, + Lists.newArrayList(shuffleServerInfo), + PARTITION_ID + 2, + Lists.newArrayList(shuffleServerInfo)); + RMRecordsReader reader = + new RMRecordsReader( + testAppId, + SHUFFLE_ID, + Sets.newHashSet(PARTITION_ID, PARTITION_ID + 1, PARTITION_ID + 2), + serverInfoMap, + rssConf, + keyClass, + valueClass, + comparator, + raw, + combiner, + false, + null); + reader.start(); + int index = 0; + KeyValueReader keyValueReader = reader.keyValueReader(); + while (keyValueReader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), keyValueReader.getCurrentKey()); + assertEquals( + SerializerUtils.genData(valueClass, index * 2), keyValueReader.getCurrentValue()); + index++; + } + assertEquals(6 * 1009, index); + shuffleWriteClientImpl.unregisterShuffle(testAppId); + } + + private static final AtomicInteger ATOMIC_INT_SORTED = new AtomicInteger(0); + + public static ShuffleBlockInfo createShuffleBlockForRemoteMerge( + RssConf rssConf, + BlockIdLayout blockIdLayout, + int taskAttemptId, + int partitionId, + List<ShuffleServerInfo> shuffleServerInfoList, + Class keyClass, + Class valueClass, + int start, + int interval, + int samples, + int duplicated) + throws IOException { + long blockId = + blockIdLayout.getBlockId(ATOMIC_INT_SORTED.getAndIncrement(), PARTITION_ID, taskAttemptId); + 
byte[] buf = + SerializerUtils.genSortedRecordBytes( + rssConf, keyClass, valueClass, start, interval, samples, duplicated); + return new ShuffleBlockInfo( + SHUFFLE_ID, + partitionId, + blockId, + buf.length, + ChecksumUtils.getCrc32(buf), + buf, + shuffleServerInfoList, + buf.length, + 0, + taskAttemptId); + } +} diff --git a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleServerClient.java b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleServerClient.java index 60297ea086..3e01d67600 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleServerClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleServerClient.java @@ -24,10 +24,12 @@ import org.apache.uniffle.client.request.RssGetShuffleIndexRequest; import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest; import org.apache.uniffle.client.request.RssGetShuffleResultRequest; +import org.apache.uniffle.client.request.RssGetSortedShuffleDataRequest; import org.apache.uniffle.client.request.RssRegisterShuffleRequest; import org.apache.uniffle.client.request.RssReportShuffleResultRequest; import org.apache.uniffle.client.request.RssSendCommitRequest; import org.apache.uniffle.client.request.RssSendShuffleDataRequest; +import org.apache.uniffle.client.request.RssStartSortMergeRequest; import org.apache.uniffle.client.request.RssUnregisterShuffleByAppIdRequest; import org.apache.uniffle.client.request.RssUnregisterShuffleRequest; import org.apache.uniffle.client.response.RssAppHeartBeatResponse; @@ -36,10 +38,12 @@ import org.apache.uniffle.client.response.RssGetShuffleDataResponse; import org.apache.uniffle.client.response.RssGetShuffleIndexResponse; import org.apache.uniffle.client.response.RssGetShuffleResultResponse; +import org.apache.uniffle.client.response.RssGetSortedShuffleDataResponse; import org.apache.uniffle.client.response.RssRegisterShuffleResponse; import org.apache.uniffle.client.response.RssReportShuffleResultResponse; import org.apache.uniffle.client.response.RssSendCommitResponse; import org.apache.uniffle.client.response.RssSendShuffleDataResponse; +import org.apache.uniffle.client.response.RssStartSortMergeResponse; import org.apache.uniffle.client.response.RssUnregisterShuffleByAppIdResponse; import org.apache.uniffle.client.response.RssUnregisterShuffleResponse; @@ -74,6 +78,10 @@ RssGetShuffleResultResponse getShuffleResultForMultiPart( RssGetInMemoryShuffleDataResponse getInMemoryShuffleData( RssGetInMemoryShuffleDataRequest request); + RssStartSortMergeResponse startSortMerge(RssStartSortMergeRequest request); + + RssGetSortedShuffleDataResponse getSortedShuffleData(RssGetSortedShuffleDataRequest request); + void close(); String getClientInfo(); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 246ae3e37f..63081041d6 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -31,6 +31,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.UnsafeByteOperations; import io.netty.buffer.Unpooled; +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,10 +44,12 @@ import 
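A note on the two methods just added to ShuffleServerClient: startSortMerge announces the expected block set for a partition, and getSortedShuffleData is then polled for merged blocks. A hedged sketch of the polling loop (consume() is a hypothetical sink, and the starting block id is an assumption; the real logic lives in RMRecordsReader):

// Assumes an initialized ShuffleServerClient for one server holding this partition.
long blockId = 1; // assumed first merged block id; the real reader tracks this itself
while (true) {
  RssGetSortedShuffleDataResponse resp =
      client.getSortedShuffleData(
          new RssGetSortedShuffleDataRequest(appId, shuffleId, partitionId, blockId));
  if (resp.getNextBlockId() == -1) {
    break; // server reports MERGING (drained for now) or DONE; see the service handler below
  }
  consume(resp.getData()); // hypothetical: hand the sorted bytes to a RecordsReader
  blockId = resp.getNextBlockId();
}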
org.apache.uniffle.client.request.RssGetShuffleIndexRequest; import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest; import org.apache.uniffle.client.request.RssGetShuffleResultRequest; +import org.apache.uniffle.client.request.RssGetSortedShuffleDataRequest; import org.apache.uniffle.client.request.RssRegisterShuffleRequest; import org.apache.uniffle.client.request.RssReportShuffleResultRequest; import org.apache.uniffle.client.request.RssSendCommitRequest; import org.apache.uniffle.client.request.RssSendShuffleDataRequest; +import org.apache.uniffle.client.request.RssStartSortMergeRequest; import org.apache.uniffle.client.request.RssUnregisterShuffleByAppIdRequest; import org.apache.uniffle.client.request.RssUnregisterShuffleRequest; import org.apache.uniffle.client.response.RssAppHeartBeatResponse; @@ -55,10 +58,12 @@ import org.apache.uniffle.client.response.RssGetShuffleDataResponse; import org.apache.uniffle.client.response.RssGetShuffleIndexResponse; import org.apache.uniffle.client.response.RssGetShuffleResultResponse; +import org.apache.uniffle.client.response.RssGetSortedShuffleDataResponse; import org.apache.uniffle.client.response.RssRegisterShuffleResponse; import org.apache.uniffle.client.response.RssReportShuffleResultResponse; import org.apache.uniffle.client.response.RssSendCommitResponse; import org.apache.uniffle.client.response.RssSendShuffleDataResponse; +import org.apache.uniffle.client.response.RssStartSortMergeResponse; import org.apache.uniffle.client.response.RssUnregisterShuffleByAppIdResponse; import org.apache.uniffle.client.response.RssUnregisterShuffleResponse; import org.apache.uniffle.common.BufferSegment; @@ -192,7 +197,12 @@ private ShuffleRegisterResponse doRegisterShuffle( String user, ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, - int stageAttemptNumber) { + int stageAttemptNumber, + String keyClassName, + String valueClassName, + String comparatorClassName, + int mergedBlockSize, + String mergeClassLoader) { ShuffleRegisterRequest.Builder reqBuilder = ShuffleRegisterRequest.newBuilder(); reqBuilder .setAppId(appId) @@ -202,6 +212,17 @@ private ShuffleRegisterResponse doRegisterShuffle( .setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite) .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges)) .setStageAttemptNumber(stageAttemptNumber); + if (StringUtils.isNotBlank(keyClassName)) { + reqBuilder.setKeyClass(keyClassName); + reqBuilder.setValueClass(valueClassName); + if (StringUtils.isNotBlank(comparatorClassName)) { + reqBuilder.setComparatorClass(comparatorClassName); + } + reqBuilder.setMergedBlockSize(mergedBlockSize); + if (StringUtils.isNotBlank(mergeClassLoader)) { + reqBuilder.setMergeClassLoader(mergeClassLoader); + } + } RemoteStorage.Builder rsBuilder = RemoteStorage.newBuilder(); rsBuilder.setPath(remoteStorageInfo.getPath()); Map remoteStorageConf = remoteStorageInfo.getConfItems(); @@ -474,7 +495,12 @@ public RssRegisterShuffleResponse registerShuffle(RssRegisterShuffleRequest requ request.getUser(), request.getDataDistributionType(), request.getMaxConcurrencyPerPartitionToWrite(), - request.getStageAttemptNumber()); + request.getStageAttemptNumber(), + request.getKeyClassName(), + request.getValueClassName(), + request.getComparatorClassName(), + request.getMergedBlockSize(), + request.getMergeClassLoader()); RssRegisterShuffleResponse response; RssProtos.StatusCode statusCode = rpcResponse.getStatus(); @@ -1091,6 +1117,122 @@ public 
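With the registerShuffle extension above, a client opts a shuffle into remote merge by passing key/value/comparator class names at registration; a blank key class means none of the merge fields are set on the proto. A hedged example using the extended constructor (argument order as in this hunk; the Hadoop class choices are illustrative, and -1 appears to mean "use the server's default merged block size"):

RssRegisterShuffleRequest request =
    new RssRegisterShuffleRequest(
        appId,
        shuffleId,
        partitionRanges,                                   // List<PartitionRange>
        remoteStorageInfo,
        user,
        ShuffleDataDistributionType.NORMAL,
        maxConcurrencyPerPartitionToWrite,
        0,                                                 // stageAttemptNumber
        org.apache.hadoop.io.Text.class.getName(),         // keyClassName
        org.apache.hadoop.io.IntWritable.class.getName(),  // valueClassName
        org.apache.hadoop.io.Text.Comparator.class.getName(), // comparatorClassName (optional)
        -1,                                                // mergedBlockSize
        null);                                             // mergeClassLoader (optional)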
RssGetInMemoryShuffleDataResponse getInMemoryShuffleData( return response; } + @Override + public RssStartSortMergeResponse startSortMerge(RssStartSortMergeRequest request) { + ByteString serializedBlockIdsBytes = ByteString.EMPTY; + try { + if (request.getExpectedTaskIds() != null) { + serializedBlockIdsBytes = + UnsafeByteOperations.unsafeWrap(RssUtils.serializeBitMap(request.getExpectedTaskIds())); + } + } catch (Exception e) { + throw new RssException("Errors on serializing task ids bitmap.", e); + } + + RssProtos.StartSortMergeRequest rpcRequest = + RssProtos.StartSortMergeRequest.newBuilder() + .setAppId(request.getAppId()) + .setShuffleId(request.getShuffleId()) + .setPartitionId(request.getPartitionId()) + .setUniqueBlocksBitmap(serializedBlockIdsBytes) + .build(); + long start = System.currentTimeMillis(); + RssProtos.StartSortMergeResponse rpcResponse = getBlockingStub().startSortMerge(rpcRequest); + String requestInfo = + "appId[" + + request.getAppId() + + "], shuffleId[" + + request.getShuffleId() + + "], partitionId[" + + request.getPartitionId() + + "]"; + LOG.info( + "startSortMerge to {}:{} for {} cost {} ms", + host, + port, + requestInfo, + (System.currentTimeMillis() - start)); + RssProtos.StatusCode statusCode = rpcResponse.getStatus(); + RssStartSortMergeResponse response; + switch (statusCode) { + case SUCCESS: + response = new RssStartSortMergeResponse(StatusCode.SUCCESS); + break; + default: + String msg = + "Can't report unique blocks to " + + host + + ":" + + port + + " for " + + requestInfo + + ", errorMsg:" + + rpcResponse.getRetMsg(); + LOG.error(msg); + throw new RssException(msg); + } + return response; + } + + @Override + public RssGetSortedShuffleDataResponse getSortedShuffleData( + RssGetSortedShuffleDataRequest request) { + long start = System.currentTimeMillis(); + RssProtos.GetSortedShuffleDataRequest rpcRequest = + RssProtos.GetSortedShuffleDataRequest.newBuilder() + .setAppId(request.getAppId()) + .setShuffleId(request.getShuffleId()) + .setPartitionId(request.getPartitionId()) + .setMergedBlockId(request.getBlockId()) + .setTimestamp(start) + .build(); + RssProtos.GetSortedShuffleDataResponse rpcResponse = + getBlockingStub().getSortedShuffleData(rpcRequest); + String requestInfo = + "appId[" + + request.getAppId() + + "], shuffleId[" + + request.getShuffleId() + + "], partitionId[" + + request.getPartitionId() + + "], blockId[" + + request.getBlockId() + + "]"; + LOG.info( + "GetSortedShuffleData from {}:{} for {} cost {} ms", + host, + port, + requestInfo, + System.currentTimeMillis() - start); + + RssProtos.StatusCode statusCode = rpcResponse.getStatus(); + + RssGetSortedShuffleDataResponse response; + switch (statusCode) { + case SUCCESS: + response = + new RssGetSortedShuffleDataResponse( + StatusCode.SUCCESS, + ByteBuffer.wrap(rpcResponse.getData().toByteArray()), + rpcResponse.getNextBlockId(), + rpcResponse.getMState()); + break; + default: + String msg = + "Can't get sorted shuffle data from " + + host + + ":" + + port + + " for " + + requestInfo + + ", errorMsg:" + + rpcResponse.getRetMsg(); + LOG.error(msg); + throw new RssFetchFailedException(msg); + } + return response; + } + @Override public String getClientInfo() { return "ShuffleServerGrpcClient for host[" + host + "], port[" + port + "]"; diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetSortedShuffleDataRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetSortedShuffleDataRequest.java new file mode 100644 index
0000000000..f3b4eb0789 --- /dev/null +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetSortedShuffleDataRequest.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.request; + +public class RssGetSortedShuffleDataRequest { + + private final String appId; + private final int shuffleId; + private final int partitionId; + private final long blockId; + + public RssGetSortedShuffleDataRequest( + String appId, int shuffleId, int partitionId, long blockId) { + this.appId = appId; + this.shuffleId = shuffleId; + this.partitionId = partitionId; + this.blockId = blockId; + } + + public String getAppId() { + return appId; + } + + public int getShuffleId() { + return shuffleId; + } + + public int getPartitionId() { + return partitionId; + } + + public long getBlockId() { + return blockId; + } +} diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java index 7e42be653e..1db40a0d1f 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java @@ -36,6 +36,11 @@ public class RssRegisterShuffleRequest { private ShuffleDataDistributionType dataDistributionType; private int maxConcurrencyPerPartitionToWrite; private int stageAttemptNumber; + private String keyClassName; + private String valueClassName; + private String comparatorClassName; + private int mergedBlockSize; + private String mergeClassLoader; public RssRegisterShuffleRequest( String appId, @@ -53,7 +58,12 @@ public RssRegisterShuffleRequest( user, dataDistributionType, maxConcurrencyPerPartitionToWrite, - 0); + 0, + null, + null, + null, + -1, + null); } public RssRegisterShuffleRequest( @@ -64,7 +74,12 @@ public RssRegisterShuffleRequest( String user, ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, - int stageAttemptNumber) { + int stageAttemptNumber, + String keyClassName, + String valueClassName, + String comparatorClassName, + int mergedBlockSize, + String mergeClassLoader) { this.appId = appId; this.shuffleId = shuffleId; this.partitionRanges = partitionRanges; @@ -73,6 +88,11 @@ public RssRegisterShuffleRequest( this.dataDistributionType = dataDistributionType; this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite; this.stageAttemptNumber = stageAttemptNumber; + this.keyClassName = keyClassName; + this.valueClassName = valueClassName; + this.comparatorClassName = comparatorClassName; + this.mergedBlockSize = mergedBlockSize; + this.mergeClassLoader = 
mergeClassLoader; } public RssRegisterShuffleRequest( @@ -90,7 +110,12 @@ public RssRegisterShuffleRequest( user, dataDistributionType, RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), - 0); + 0, + null, + null, + null, + -1, + null); } public RssRegisterShuffleRequest( @@ -103,7 +128,12 @@ public RssRegisterShuffleRequest( StringUtils.EMPTY, ShuffleDataDistributionType.NORMAL, RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), - 0); + 0, + null, + null, + null, + -1, + null); } public String getAppId() { @@ -137,4 +167,24 @@ public int getMaxConcurrencyPerPartitionToWrite() { public int getStageAttemptNumber() { return stageAttemptNumber; } + + public String getKeyClassName() { + return keyClassName; + } + + public String getValueClassName() { + return valueClassName; + } + + public String getComparatorClassName() { + return comparatorClassName; + } + + public int getMergedBlockSize() { + return mergedBlockSize; + } + + public String getMergeClassLoader() { + return mergeClassLoader; + } } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssStartSortMergeRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssStartSortMergeRequest.java new file mode 100644 index 0000000000..03a876c33c --- /dev/null +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssStartSortMergeRequest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.request; + +import org.roaringbitmap.longlong.Roaring64NavigableMap; + +public class RssStartSortMergeRequest { + + private String appId; + private int shuffleId; + private int partitionId; + private Roaring64NavigableMap expectedBlockIds; + + public RssStartSortMergeRequest( + String appId, int shuffleId, int partitionId, Roaring64NavigableMap expectedBlockIds) { + this.appId = appId; + this.shuffleId = shuffleId; + this.partitionId = partitionId; + this.expectedBlockIds = expectedBlockIds; + } + + public String getAppId() { + return appId; + } + + public int getShuffleId() { + return shuffleId; + } + + public int getPartitionId() { + return partitionId; + } + + public Roaring64NavigableMap getExpectedTaskIds() { + return expectedBlockIds; + } +} diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetSortedShuffleDataResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetSortedShuffleDataResponse.java new file mode 100644 index 0000000000..fa153e3d6e --- /dev/null +++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetSortedShuffleDataResponse.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.client.response; + +import java.nio.ByteBuffer; + +import org.apache.uniffle.common.rpc.StatusCode; + +public class RssGetSortedShuffleDataResponse extends ClientResponse { + + private final ByteBuffer data; + private final long nextBlockId; + private final int mergeState; + + public RssGetSortedShuffleDataResponse( + StatusCode statusCode, ByteBuffer data, long nextBlockId, int mergeState) { + super(statusCode); + this.data = data; + this.nextBlockId = nextBlockId; + this.mergeState = mergeState; + } + + public ByteBuffer getData() { + return data; + } + + public long getNextBlockId() { + return nextBlockId; + } + + public int getMergeState() { + return mergeState; + } +} diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssStartSortMergeResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssStartSortMergeResponse.java new file mode 100644 index 0000000000..82a6213a6d --- /dev/null +++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssStartSortMergeResponse.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.client.response; + +import org.apache.uniffle.common.rpc.StatusCode; + +public class RssStartSortMergeResponse extends ClientResponse { + + public RssStartSortMergeResponse(StatusCode statusCode) { + super(statusCode); + } +} diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 0bf258ad9a..06d781e134 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -38,6 +38,8 @@ service ShuffleServer { rpc finishShuffle (FinishShuffleRequest) returns (FinishShuffleResponse); rpc requireBuffer (RequireBufferRequest) returns (RequireBufferResponse); rpc appHeartbeat(AppHeartBeatRequest) returns (AppHeartBeatResponse); + rpc startSortMerge (StartSortMergeRequest) returns (StartSortMergeResponse); + rpc getSortedShuffleData (GetSortedShuffleDataRequest) returns (GetSortedShuffleDataResponse); } message FinishShuffleRequest { @@ -186,6 +188,11 @@ message ShuffleRegisterRequest { DataDistribution shuffleDataDistribution = 6; int32 maxConcurrencyPerPartitionToWrite = 7; int32 stageAttemptNumber = 8; + string keyClass = 9; + string valueClass = 10; + string comparatorClass = 11; + int32 mergedBlockSize = 12; + string mergeClassLoader = 13; } enum DataDistribution { @@ -662,4 +669,30 @@ message ReassignOnBlockSendFailureResponse { MutableShuffleHandleInfo handle = 3; } +message StartSortMergeRequest { + string appId = 1; + int32 shuffleId = 2; + int32 partitionId = 3; + bytes uniqueBlocksBitmap = 4; +} + +message StartSortMergeResponse { + StatusCode status = 1; + string retMsg = 2; +} + +message GetSortedShuffleDataRequest { + string appId = 1; + int32 shuffleId = 2; + int32 partitionId = 3; + int64 mergedBlockId = 4; + int64 timestamp = 5; +} +message GetSortedShuffleDataResponse { + bytes data = 1; + StatusCode status = 2; + string retMsg = 3; + int64 nextBlockId = 4; + int32 mState = 5; +} diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index 5cce3a3a8b..e695163319 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -17,6 +17,7 @@ package org.apache.uniffle.server; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.Collections; import java.util.List; @@ -54,6 +55,7 @@ import org.apache.uniffle.common.exception.NoBufferException; import org.apache.uniffle.common.exception.NoBufferForHugePartitionException; import org.apache.uniffle.common.exception.NoRegisterException; +import org.apache.uniffle.common.merger.MergeState; import org.apache.uniffle.common.rpc.ClientContextServerInterceptor; import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.util.BlockIdLayout; @@ -93,10 +95,13 @@ import org.apache.uniffle.proto.ShuffleServerGrpc.ShuffleServerImplBase; import org.apache.uniffle.server.audit.ServerRpcAuditContext; import org.apache.uniffle.server.buffer.PreAllocatedBufferInfo; +import org.apache.uniffle.server.merge.MergeStatus; import org.apache.uniffle.storage.common.Storage; import org.apache.uniffle.storage.common.StorageReadMetrics; import org.apache.uniffle.storage.util.ShuffleStorageUtils; +import static org.apache.uniffle.server.merge.ShuffleMergeManager.MERGE_APP_SUFFIX; + public class ShuffleServerGrpcService extends ShuffleServerImplBase { private static final Logger LOG = 
LoggerFactory.getLogger(ShuffleServerGrpcService.class); @@ -144,7 +149,9 @@ public void unregisterShuffleByAppId( String responseMessage = "OK"; try { shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId); - + if (shuffleServer.isRemoteMergeEnable()) { + shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId + MERGE_APP_SUFFIX); + } } catch (Exception e) { status = StatusCode.INTERNAL_ERROR; } @@ -183,6 +190,11 @@ public void unregisterShuffle( String responseMessage = "OK"; try { shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId, shuffleId); + if (shuffleServer.isRemoteMergeEnable()) { + shuffleServer + .getShuffleTaskManager() + .removeShuffleDataAsync(appId + MERGE_APP_SUFFIX, shuffleId); + } } catch (Exception e) { status = StatusCode.INTERNAL_ERROR; } @@ -297,6 +309,36 @@ public void registerShuffle( user, shuffleDataDistributionType, maxConcurrencyPerPartitionToWrite); + if (StatusCode.SUCCESS == result + && shuffleServer.isRemoteMergeEnable() + && StringUtils.isNotBlank(req.getKeyClass())) { + // The merged block is in a different domain from the original block, + // so you need to register a new app for holding the merged block. + result = + shuffleServer + .getShuffleTaskManager() + .registerShuffle( + appId + MERGE_APP_SUFFIX, + shuffleId, + partitionRanges, + new RemoteStorageInfo(remoteStoragePath, remoteStorageConf), + user, + shuffleDataDistributionType, + maxConcurrencyPerPartitionToWrite); + if (result == StatusCode.SUCCESS) { + result = + shuffleServer + .getShuffleMergeManager() + .registerShuffle( + appId, + shuffleId, + req.getKeyClass(), + req.getValueClass(), + req.getComparatorClass(), + req.getMergedBlockSize(), + req.getMergeClassLoader()); + } + } auditContext.withStatusCode(result); reply = ShuffleRegisterResponse.newBuilder().setStatus(result.toProto()).build(); responseObserver.onNext(reply); @@ -435,6 +477,10 @@ public void sendShuffleData( hasFailureOccurred = true; break; } else { + if (shuffleServer.isRemoteMergeEnable()) { + // TODO: Use ShuffleBufferWithSkipList to avoid caching block here. + shuffleServer.getShuffleMergeManager().cacheBlock(appId, shuffleId, spd); + } long toReleasedSize = spd.getTotalBlockSize(); // after each cacheShuffleData call, the `preAllocatedSize` is updated timely. 
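A note on the server-side pattern in this hunk: merged output is stored under a shadow app id (appId + MERGE_APP_SUFFIX), which is why register, unregister, heartbeat, and data-removal calls are all mirrored whenever remote merge is enabled. A condensed view of the mirroring (method shape simplified from the surrounding handlers):

// Every lifecycle operation on the original app is repeated for the shadow app.
void removeApp(String appId) {
  shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId);
  if (shuffleServer.isRemoteMergeEnable()) {
    shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId + MERGE_APP_SUFFIX);
  }
}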
manager.releasePreAllocatedSize(toReleasedSize); @@ -732,6 +778,9 @@ public void appHeartbeat( LOG.info("Get heartbeat from {}", appId); auditContext.withStatusCode(StatusCode.SUCCESS); shuffleServer.getShuffleTaskManager().refreshAppId(appId); + if (shuffleServer.isRemoteMergeEnable()) { + shuffleServer.getShuffleMergeManager().refreshAppId(appId); + } AppHeartBeatResponse response = AppHeartBeatResponse.newBuilder() .setRetMsg("") @@ -1364,6 +1413,226 @@ public void getMemoryShuffleData( } } + + @Override + public void startSortMerge( + RssProtos.StartSortMergeRequest request, + StreamObserver<RssProtos.StartSortMergeResponse> responseObserver) { + try (ServerRpcAuditContext auditContext = createAuditContext("startSortMerge")) { + String appId = request.getAppId(); + int shuffleId = request.getShuffleId(); + int partitionId = request.getPartitionId(); + + auditContext + .withAppId(appId) + .withShuffleId(shuffleId) + .withArgs(String.format("partitionId=%d", partitionId)); + + StatusCode status = StatusCode.SUCCESS; + String msg = "OK"; + RssProtos.StartSortMergeResponse reply; + String requestInfo = + "appId[" + appId + "], shuffleId[" + shuffleId + "], partitionId[" + partitionId + "]"; + try { + Roaring64NavigableMap expectedBlockIdMap = + RssUtils.deserializeBitMap(request.getUniqueBlocksBitmap().toByteArray()); + LOG.info( + "Report " + + expectedBlockIdMap.getLongCardinality() + + " unique blocks for " + + requestInfo); + if (shuffleServer.isRemoteMergeEnable()) { + shuffleServer + .getShuffleMergeManager() + .startSortMerge(appId, shuffleId, partitionId, expectedBlockIdMap); + } else { + status = StatusCode.INTERNAL_ERROR; + msg = "Remote merge is disabled, cannot start sort merge!"; + } + } catch (IOException e) { + status = StatusCode.INTERNAL_ERROR; + msg = e.getMessage(); + LOG.error("Error happened when reporting unique blocks for {}", requestInfo, e); + } + reply = + RssProtos.StartSortMergeResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(msg) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } + + @Override + public void getSortedShuffleData( + RssProtos.GetSortedShuffleDataRequest request, + StreamObserver<RssProtos.GetSortedShuffleDataResponse> responseObserver) { + try (ServerRpcAuditContext auditContext = createAuditContext("getSortedShuffleData")) { + String appId = request.getAppId(); + int shuffleId = request.getShuffleId(); + int partitionId = request.getPartitionId(); + long blockId = request.getMergedBlockId(); + long timestamp = request.getTimestamp(); + auditContext + .withAppId(appId) + .withShuffleId(shuffleId) + .withArgs(String.format("partitionId=%d, blockId=%d", partitionId, blockId)); + + if (timestamp > 0) { + long transportTime = System.currentTimeMillis() - timestamp; + if (transportTime > 0) { + shuffleServer + .getGrpcMetrics() + .recordTransportTime(ShuffleServerGrpcMetrics.GET_SHUFFLE_DATA_METHOD, transportTime); + } + } + StatusCode status = StatusCode.SUCCESS; + String msg = "OK"; + RssProtos.GetSortedShuffleDataResponse reply = null; + ShuffleDataResult sdr = null; + String requestInfo = + "appId[" + + appId + + "], shuffleId[" + + shuffleId + + "], partitionId[" + + partitionId + + "], blockId[" + + blockId + + "]"; + + if (!shuffleServer.isRemoteMergeEnable()) { + msg = "Remote merge is disabled"; + status = StatusCode.INTERNAL_ERROR; + reply = + RssProtos.GetSortedShuffleDataResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(msg) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } + +
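The branches that follow implement a small polling protocol around MergeState. A client-side reading of it, sketched under the assumption that retryLater(), consume(), and finish() are hypothetical helpers (the real handling is inside RMRecordsReader):

// state INITED                  -> merge not started yet; ask again later
// state MERGING, nextBlockId -1 -> caught up with the merger; more data may still come
// state DONE, nextBlockId -1    -> every merged block has been served
// state INTERNAL_ERROR          -> fail the read
// otherwise                     -> a merged block is returned in the response body
if (resp.getMergeState() == MergeState.INITED.code()) {
  retryLater();
} else if (resp.getNextBlockId() == -1 && resp.getMergeState() == MergeState.MERGING.code()) {
  retryLater();
} else if (resp.getNextBlockId() == -1 && resp.getMergeState() == MergeState.DONE.code()) {
  finish();
} else {
  consume(resp.getData());
}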
MergeStatus mergeStatus = + shuffleServer + .getShuffleMergeManager() + .tryGetBlock(appId, shuffleId, partitionId, blockId); + MergeState mergeState = mergeStatus.getState(); + long blockSize = mergeStatus.getSize(); + if (mergeState == MergeState.INITED) { + msg = MergeState.INITED.name(); + reply = + RssProtos.GetSortedShuffleDataResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(msg) + .setMState(mergeState.code()) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } else if (mergeState == MergeState.MERGING && blockSize == -1) { + // Notify the client that all merged data has been read, but there may be data that has not + // yet been merged. + msg = MergeState.MERGING.name(); + reply = + RssProtos.GetSortedShuffleDataResponse.newBuilder() + .setStatus(status.toProto()) + .setNextBlockId(-1) + .setRetMsg(msg) + .setMState(mergeState.code()) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } else if (mergeState == MergeState.DONE && blockSize == -1) { + // Notify the client that all data has been read + msg = MergeState.DONE.name(); + reply = + RssProtos.GetSortedShuffleDataResponse.newBuilder() + .setStatus(status.toProto()) + .setNextBlockId(-1) + .setRetMsg(msg) + .setMState(mergeState.code()) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } else if (mergeState == MergeState.INTERNAL_ERROR) { + msg = MergeState.INTERNAL_ERROR.name(); + status = StatusCode.INTERNAL_ERROR; + reply = + RssProtos.GetSortedShuffleDataResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(msg) + .setMState(mergeState.code()) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + return; + } + + if (shuffleServer.getShuffleBufferManager().requireReadMemory(blockSize)) { + try { + long start = System.currentTimeMillis(); + sdr = + shuffleServer + .getShuffleMergeManager() + .getShuffleData(appId, shuffleId, partitionId, blockId); + long readTime = System.currentTimeMillis() - start; + ShuffleServerMetrics.counterTotalReadTime.inc(readTime); + ShuffleServerMetrics.counterTotalReadDataSize.inc(sdr.getDataLength()); + ShuffleServerMetrics.counterTotalReadLocalDataFileSize.inc(sdr.getDataLength()); + shuffleServer + .getGrpcMetrics() + .recordProcessTime(ShuffleServerGrpcMetrics.GET_SHUFFLE_DATA_METHOD, readTime); + LOG.info( + "Successfully getSortedShuffleData cost {} ms for shuffle" + + " data with {}, length is {}, state is {}", + readTime, + requestInfo, + sdr.getDataLength(), + mergeState); + auditContext.withReturnValue("len=" + sdr.getDataLength()); + reply = + RssProtos.GetSortedShuffleDataResponse.newBuilder() + .setNextBlockId(blockId + 1) // next block id + .setMState(mergeState.code()) + .setStatus(status.toProto()) + .setRetMsg(msg) + .setData(UnsafeByteOperations.unsafeWrap(sdr.getData())) + .build(); + } catch (Exception e) { + status = StatusCode.INTERNAL_ERROR; + msg = "Error happened when getting shuffle data for " + requestInfo + ", " + e.getMessage(); + LOG.error(msg, e); + reply = + RssProtos.GetSortedShuffleDataResponse.newBuilder() + .setStatus(status.toProto()) + .setRetMsg(msg) + .build(); + } finally { + if (sdr != null) { + sdr.release(); + } + shuffleServer.getShuffleBufferManager().releaseReadMemory(blockSize); + } + } else { + status = StatusCode.INTERNAL_ERROR; + msg = "Can't require memory to get shuffle data"; + LOG.error(msg + " for " + requestInfo); + reply = + RssProtos.GetSortedShuffleDataResponse.newBuilder() 
+ .setStatus(status.toProto()) + .setRetMsg(msg) + .build(); + } + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } + private List toPartitionedData(SendShuffleDataRequest req) { List ret = Lists.newArrayList(); diff --git a/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java b/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java index 534c74b7d9..6b6c870511 100644 --- a/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java +++ b/server/src/test/java/org/apache/uniffle/server/merge/BlockFlushFileReaderTest.java @@ -42,7 +42,6 @@ import org.apache.uniffle.common.merger.StreamedSegment; import org.apache.uniffle.common.records.RecordsReader; import org.apache.uniffle.common.serializer.PartialInputStream; -import org.apache.uniffle.common.serializer.PartialInputStreamImpl; import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler; @@ -121,11 +120,7 @@ public void writeTestWithMerge(String classes, @TempDir File tmpDir) throws Exce int index = 0; RecordsReader reader = new RecordsReader( - conf, - PartialInputStreamImpl.newInputStream(dataOutput, 0, dataOutput.length()), - keyClass, - valueClass, - false); + conf, PartialInputStream.newInputStream(dataOutput), keyClass, valueClass, false); while (reader.next()) { assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); diff --git a/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java b/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java index e8b09a5fd7..e8dcda7abf 100644 --- a/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java +++ b/server/src/test/java/org/apache/uniffle/server/merge/MergedResultTest.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; import java.io.OutputStream; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -36,7 +37,7 @@ import org.apache.uniffle.common.merger.Recordable; import org.apache.uniffle.common.merger.Segment; import org.apache.uniffle.common.records.RecordsReader; -import org.apache.uniffle.common.serializer.PartialInputStreamImpl; +import org.apache.uniffle.common.serializer.PartialInputStream; import org.apache.uniffle.common.serializer.SerializerUtils; import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE; @@ -160,7 +161,7 @@ public void testMergeSegmentToMergeResult(String classes, @TempDir File tmpDir) RecordsReader reader = new RecordsReader( rssConf, - PartialInputStreamImpl.newInputStream(buffer, 0, length), + PartialInputStream.newInputStream(ByteBuffer.wrap(buffer)), keyClass, valueClass, false); diff --git a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java index 4b545d3a73..69e2a57886 100644 --- a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java +++ b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java @@ -41,7 +41,7 @@ import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.merger.MergeState; import org.apache.uniffle.common.records.RecordsReader; -import 
org.apache.uniffle.common.serializer.PartialInputStreamImpl; +import org.apache.uniffle.common.serializer.PartialInputStream; import org.apache.uniffle.common.serializer.SerializerUtils; import org.apache.uniffle.common.serializer.writable.WritableSerializer; import org.apache.uniffle.common.util.BlockIdLayout; @@ -194,9 +194,8 @@ public void testMergerManager(String classes, @TempDir File tmpDir) throws Excep if (blockSize != -1) { ShuffleDataResult shuffleDataResult = mergeManager.getShuffleData(APP_ID, SHUFFLE_ID, PARTITION_ID, blockId); - PartialInputStreamImpl inputStream = - PartialInputStreamImpl.newInputStream( - shuffleDataResult.getData(), 0, shuffleDataResult.getDataLength()); + PartialInputStream inputStream = + PartialInputStream.newInputStream(shuffleDataResult.getDataBuffer()); RecordsReader reader = new RecordsReader(serverConf, inputStream, keyClass, valueClass, false); while (reader.next()) {
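The test churn above is the visible effect of replacing PartialInputStreamImpl and SeekableInMemoryByteChannel with PartialInputStream factory methods. A usage sketch under that assumption, mirroring the updated tests (generics elided as in the surrounding code; handle() is a hypothetical consumer):

ShuffleDataResult shuffleDataResult =
    mergeManager.getShuffleData(APP_ID, SHUFFLE_ID, PARTITION_ID, blockId);
PartialInputStream inputStream =
    PartialInputStream.newInputStream(shuffleDataResult.getDataBuffer());
RecordsReader reader = new RecordsReader(serverConf, inputStream, keyClass, valueClass, false);
while (reader.next()) {
  // keys arrive in comparator order; values are the merged records
  handle(reader.getCurrentKey(), reader.getCurrentValue());
}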