[#2173] feat(remote merge): support netty for remote merge. #2202

Status: Open. Wants to merge 3 commits into base: master.
----- file 1 -----
@@ -121,6 +121,8 @@ public void init(Context<K, V> context) {
     }
     Map<Integer, List<ShuffleServerInfo>> serverInfoMap = new HashMap<>();
     serverInfoMap.put(partitionId, new ArrayList<>(serverInfoSet));
+    String clientType =
+        rssJobConf.get(RssMRConfig.RSS_CLIENT_TYPE, RssMRConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);
     this.reader =
         new RMRecordsReader(
             appId,
@@ -134,7 +136,8 @@ public void init(Context<K, V> context) {
             true,
             combiner,
             combiner != null,
-            new MRMetricsReporter(context.getReporter()));
+            new MRMetricsReporter(context.getReporter()),
+            clientType);
   }

   @Override
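
Note: a minimal sketch of how a job would opt into the Netty transport that the new clientType argument carries. The RssMRConfig constant is the one read in the hunk above; the import path and the literal "NETTY" value are assumptions about the surrounding codebase, not something this diff shows.

import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapreduce.RssMRConfig; // assumed package for the constant used above

public class NettyClientTypeExample {
  public static void main(String[] args) {
    JobConf jobConf = new JobConf();
    // Assumed value; if unset, RSS_CLIENT_TYPE_DEFAULT_VALUE applies, as in the hunk above.
    jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, "NETTY");
  }
}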
----- file 2 -----
@@ -57,7 +57,7 @@
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.records.RecordsReader;
 import org.apache.uniffle.common.rpc.StatusCode;
-import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
+import org.apache.uniffle.common.serializer.SerInputStream;
 import org.apache.uniffle.common.serializer.SerializerFactory;
 import org.apache.uniffle.common.serializer.SerializerInstance;
 import org.apache.uniffle.common.serializer.SerializerUtils;
@@ -523,11 +523,8 @@ public void testWriteNormalWithRemoteMerge() throws Exception {
     ByteBuf byteBuf = blockInfos.get(0).getData();
     RecordsReader<Text, Text> reader =
         new RecordsReader<>(
-            rssConf,
-            PartialInputStreamImpl.newInputStream(byteBuf.nioBuffer()),
-            Text.class,
-            Text.class,
-            false);
+            rssConf, SerInputStream.newInputStream(byteBuf), Text.class, Text.class, false, false);
+    reader.init();
     int index = 0;
     while (reader.next()) {
       assertEquals(SerializerUtils.genData(Text.class, index), reader.getCurrentKey());
@@ -610,10 +607,12 @@ public void testWriteNormalWithRemoteMergeAndCombine() throws Exception {
     RecordsReader<Text, IntWritable> reader =
         new RecordsReader<>(
             rssConf,
-            PartialInputStreamImpl.newInputStream(byteBuf.nioBuffer()),
+            SerInputStream.newInputStream(byteBuf),
             Text.class,
             IntWritable.class,
+            false,
             false);
+    reader.init();
     int index = 0;
     while (reader.next()) {
       int aimValue = index;
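
Note: both hunks above converge on the same updated read pattern: the reader now wraps the Netty ByteBuf directly via SerInputStream (no byteBuf.nioBuffer() copy), takes a second boolean flag whose meaning this diff does not show (false/false simply mirrors the tests), and must be init()'d before iteration. A condensed sketch of that pattern, built only from calls visible in the diff:

import io.netty.buffer.ByteBuf;
import org.apache.hadoop.io.Text;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.records.RecordsReader;
import org.apache.uniffle.common.serializer.SerInputStream;

public class ReadPatternSketch {
  // Counts the records in a serialized block using the updated API.
  static int countRecords(RssConf rssConf, ByteBuf byteBuf) throws Exception {
    RecordsReader<Text, Text> reader =
        new RecordsReader<>(
            rssConf, SerInputStream.newInputStream(byteBuf), Text.class, Text.class, false, false);
    reader.init(); // new in this patch: initialize before reading
    int count = 0;
    while (reader.next()) {
      count++; // getCurrentKey()/getCurrentValue() expose each record, as the assertions above do
    }
    return count;
  }
}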
----- file 3 -----
@@ -17,8 +17,6 @@

 package org.apache.hadoop.mapreduce.task.reduce;

-import java.io.ByteArrayOutputStream;
-import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -29,6 +27,7 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
+import io.netty.buffer.ByteBuf;
 import org.apache.hadoop.io.DataInputBuffer;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.RawComparator;
@@ -58,13 +57,15 @@
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.merger.Merger;
 import org.apache.uniffle.common.merger.Segment;
+import org.apache.uniffle.common.serializer.DynBufferSerOutputStream;
+import org.apache.uniffle.common.serializer.SerOutputStream;
 import org.apache.uniffle.common.serializer.Serializer;
 import org.apache.uniffle.common.serializer.SerializerFactory;
 import org.apache.uniffle.common.serializer.SerializerInstance;
 import org.apache.uniffle.common.serializer.SerializerUtils;

 import static org.apache.uniffle.common.serializer.SerializerUtils.genData;
-import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes;
+import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
@@ -136,12 +137,10 @@ public void testReadShuffleWithoutCombine() throws Exception {
             combiner,
             false,
             null);
-    ByteBuffer byteBuffer =
-        ByteBuffer.wrap(
-            genSortedRecordBytes(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1));
+    ByteBuf byteBuf = genSortedRecordBuffer(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, 1);
     ShuffleServerClient serverClient =
         new MockedShuffleServerClient(
-            new int[] {PARTITION_ID}, new ByteBuffer[][] {{byteBuffer}}, blockIds);
+            new int[] {PARTITION_ID}, new ByteBuf[][] {{byteBuf}}, blockIds);
     RMRecordsReader readerSpy = spy(reader);
     doReturn(serverClient).when(readerSpy).createShuffleServerClient(any());

@@ -170,6 +169,7 @@ public void testReadShuffleWithoutCombine() throws Exception {
       index++;
     }
     assertEquals(RECORDS_NUM, index);
+    byteBuf.release();
   }
 }

@@ -219,20 +219,21 @@ public void testReadShuffleWithCombine() throws Exception {
     List<Segment> segments = new ArrayList<>();
     segments.add(
         SerializerUtils.genMemorySegment(
-            rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM, true));
+            rssConf, keyClass, valueClass, 0L, 0, 2, RECORDS_NUM, true, false));
     segments.add(
         SerializerUtils.genMemorySegment(
-            rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM, true));
+            rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM, true, false));
     segments.add(
         SerializerUtils.genMemorySegment(
-            rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM, true));
-    ByteArrayOutputStream output = new ByteArrayOutputStream();
+            rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM, true, false));
+    segments.forEach(segment -> segment.init());
+    SerOutputStream output = new DynBufferSerOutputStream();
     Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, true);
     output.close();
-    ByteBuffer byteBuffer = ByteBuffer.wrap(output.toByteArray());
+    ByteBuf byteBuf = output.toByteBuf();
     ShuffleServerClient serverClient =
         new MockedShuffleServerClient(
-            new int[] {PARTITION_ID}, new ByteBuffer[][] {{byteBuffer}}, blockIds);
+            new int[] {PARTITION_ID}, new ByteBuf[][] {{byteBuf}}, blockIds);
     RMRecordsReader reader =
         new RMRecordsReader(
             APP_ID,
@@ -280,6 +281,7 @@ public void testReadShuffleWithCombine() throws Exception {
       index++;
     }
     assertEquals(RECORDS_NUM * 2, index);
+    byteBuf.release();
   }
 }
}
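
Note: the combine test above also switches the merge output from ByteArrayOutputStream to the buffer-backed stream API. A sketch of that flow, assembled only from calls visible in the hunks (the final boolean of Merger.merge mirrors the call sites; its exact meaning is not shown in this diff):

import java.util.Comparator;
import java.util.List;

import io.netty.buffer.ByteBuf;

import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.merger.Merger;
import org.apache.uniffle.common.merger.Segment;
import org.apache.uniffle.common.serializer.DynBufferSerOutputStream;
import org.apache.uniffle.common.serializer.SerOutputStream;

public class MergeFlowSketch {
  // Merges pre-built segments into one reference-counted buffer.
  static ByteBuf mergeToByteBuf(
      RssConf rssConf, List<Segment> segments, Class keyClass, Class valueClass,
      Comparator comparator, boolean flag) throws Exception {
    segments.forEach(segment -> segment.init()); // segments now need an explicit init()
    SerOutputStream output = new DynBufferSerOutputStream();
    Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, flag);
    output.close();
    return output.toByteBuf(); // caller owns the ByteBuf and must release() it when done
  }
}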
----- file 4 -----
@@ -73,6 +73,7 @@ public class RMRssShuffle implements ExceptionReporter {
   private ShuffleInputEventHandlerOrderedGrouped eventHandler;
   private final TezTaskAttemptID tezTaskAttemptID;
   private final String srcNameTrimmed;
+  private final String clientType;
   private Map<Integer, List<ShuffleServerInfo>> partitionToServers;

   private AtomicBoolean isShutDown = new AtomicBoolean(false);
@@ -101,6 +102,8 @@ public RMRssShuffle(
     this.numInputs = numInputs;
     this.shuffleId = shuffleId;
     this.applicationAttemptId = applicationAttemptId;
+    this.clientType =
+        conf.get(RssTezConfig.RSS_CLIENT_TYPE, RssTezConfig.RSS_CLIENT_TYPE_DEFAULT_VALUE);
     this.appId = this.applicationAttemptId.toString();
     this.srcNameTrimmed = TezUtilsInternal.cleanVertexName(inputContext.getSourceVertexName());
     LOG.info(srcNameTrimmed + ": Shuffle assigned with " + numInputs + " inputs.");
@@ -254,7 +257,8 @@ public RMRecordsReader createRMRecordsReader(Set partitionIds) {
         false,
         (inc) -> {
           inputRecordCounter.increment(inc);
-        });
+        },
+        this.clientType);
   }

   @Override
----- file 5 -----
@@ -20,6 +20,7 @@
 import java.lang.reflect.Method;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
@@ -28,6 +29,7 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
+import io.netty.buffer.ByteBuf;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.io.IntWritable;
@@ -74,7 +76,7 @@
 import static org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS;
 import static org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_KEY_COMPARATOR_CLASS;
 import static org.apache.tez.runtime.library.api.TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS;
-import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes;
+import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
@@ -158,12 +160,9 @@ public void testReadShuffleData() throws Exception {
             false,
             null);
     RMRecordsReader recordsReaderSpy = spy(recordsReader);
-    ByteBuffer[][] buffers =
-        new ByteBuffer[][] {
-          {
-            ByteBuffer.wrap(
-                genSortedRecordBytes(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, duplicated))
-          }
+    ByteBuf[][] buffers =
+        new ByteBuf[][] {
+          {genSortedRecordBuffer(rssConf, keyClass, valueClass, 0, 1, RECORDS_NUM, duplicated)}
         };
     ShuffleServerClient serverClient =
         new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds);
@@ -228,6 +227,7 @@ public void testReadShuffleData() throws Exception {
       index++;
     }
     assertEquals(RECORDS_NUM, index);
+    Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release()));
   }

   @Test
@@ -309,15 +309,13 @@ public void testReadMultiPartitionShuffleData() throws Exception {
             false,
             null);
     RMRecordsReader recordsReaderSpy = spy(recordsReader);
-    ByteBuffer[][] buffers = new ByteBuffer[3][2];
+    ByteBuf[][] buffers = new ByteBuf[3][2];
     for (int i = 0; i < 3; i++) {
       buffers[i][0] =
-          ByteBuffer.wrap(
-              genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, duplicated));
+          genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, duplicated);
       buffers[i][1] =
-          ByteBuffer.wrap(
-              genSortedRecordBytes(
-                  rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, duplicated));
+          genSortedRecordBuffer(
+              rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, duplicated);
     }
     ShuffleServerClient serverClient =
         new MockedShuffleServerClient(
@@ -396,6 +394,7 @@ public void testReadMultiPartitionShuffleData() throws Exception {
       index++;
     }
     assertEquals(RECORDS_NUM * 6, index);
+    Arrays.stream(buffers).forEach(bs -> Arrays.stream(bs).forEach(b -> b.release()));
   }

   public static DataMovementEvent createDataMovementEvent(int partition, String path)
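
Note: the new release() calls exist because genSortedRecordBuffer hands back Netty ByteBufs, which are reference-counted rather than garbage-collected like the ByteBuffers they replace; a test that forgets to release them leaks buffer memory. A self-contained illustration with plain Netty, independent of any Uniffle types:

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

public class RefCountDemo {
  public static void main(String[] args) {
    ByteBuf buf = Unpooled.buffer(16); // refCnt() == 1 on allocation
    buf.writeInt(42);
    System.out.println(buf.readInt()); // 42
    buf.release(); // refCnt() drops to 0 and the backing memory is reclaimed
  }
}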
----- file 6 -----
@@ -19,7 +19,6 @@

 import java.io.File;
 import java.io.IOException;
-import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
@@ -73,7 +72,7 @@
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.records.RecordsReader;
 import org.apache.uniffle.common.rpc.StatusCode;
-import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
+import org.apache.uniffle.common.serializer.SerInputStream;
 import org.apache.uniffle.common.serializer.SerializerFactory;
 import org.apache.uniffle.common.serializer.SerializerInstance;
 import org.apache.uniffle.common.serializer.SerializerUtils;
@@ -617,11 +616,8 @@ public void testWriteWithRemoteMerge() throws Exception {
     buf.readBytes(bytes);
     RecordsReader<Text, Text> reader =
         new RecordsReader<>(
-            rssConf,
-            PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes)),
-            Text.class,
-            Text.class,
-            false);
+            rssConf, SerInputStream.newInputStream(buf), Text.class, Text.class, false, false);
+    reader.init();
     int index = 0;
     while (reader.next()) {
       assertEquals(SerializerUtils.genData(Text.class, index), reader.getCurrentKey());
----- file 7 -----
@@ -27,6 +27,7 @@
 import java.util.Map;
 import java.util.Random;

+import io.netty.buffer.Unpooled;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.io.IntWritable;
@@ -41,8 +42,7 @@

 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.serializer.DeserializationStream;
-import org.apache.uniffle.common.serializer.PartialInputStream;
-import org.apache.uniffle.common.serializer.PartialInputStreamImpl;
+import org.apache.uniffle.common.serializer.SerInputStream;
 import org.apache.uniffle.common.serializer.SerializerFactory;
 import org.apache.uniffle.common.serializer.SerializerInstance;

@@ -203,9 +203,11 @@ public void testReadWriteWithRemoteMergeAndNoSort() throws IOException {
       buffer.addRecord(genData(Text.class, i), genData(IntWritable.class, i));
     }
     byte[] bytes = buffer.getData();
-    PartialInputStream inputStream = PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes));
+    SerInputStream inputStream =
+        SerInputStream.newInputStream(Unpooled.wrappedBuffer(ByteBuffer.wrap(bytes)));
     DeserializationStream dStream =
-        instance.deserializeStream(inputStream, Text.class, IntWritable.class, false);
+        instance.deserializeStream(inputStream, Text.class, IntWritable.class, false, false);
+    dStream.init();
     for (int i = 0; i < RECORDS_NUM; i++) {
       assertTrue(dStream.nextRecord());
       assertEquals(genData(Text.class, i), dStream.getCurrentKey());
@@ -240,9 +242,11 @@ public void testReadWriteWithRemoteMergeAndSort() throws IOException {
       buffer.addRecord(genData(Text.class, i), genData(IntWritable.class, i));
     }
     byte[] bytes = buffer.getData();
-    PartialInputStream inputStream = PartialInputStreamImpl.newInputStream(ByteBuffer.wrap(bytes));
+    SerInputStream inputStream =
+        SerInputStream.newInputStream(Unpooled.wrappedBuffer(ByteBuffer.wrap(bytes)));
     DeserializationStream dStream =
-        instance.deserializeStream(inputStream, Text.class, IntWritable.class, false);
+        instance.deserializeStream(inputStream, Text.class, IntWritable.class, false, false);
+    dStream.init();
     for (int i = 0; i < RECORDS_NUM; i++) {
       assertTrue(dStream.nextRecord());
       assertEquals(genData(Text.class, i), dStream.getCurrentKey());
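
Note: where the data only exists as a heap byte[], both hunks above adapt it with Unpooled.wrappedBuffer(), which wraps without copying, and then drive the stream through the new init()/nextRecord() lifecycle. A condensed sketch of that adaptation (the extra boolean passed to deserializeStream is not explained by this diff; false/false mirrors the tests):

import java.nio.ByteBuffer;

import io.netty.buffer.Unpooled;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

import org.apache.uniffle.common.serializer.DeserializationStream;
import org.apache.uniffle.common.serializer.SerInputStream;
import org.apache.uniffle.common.serializer.SerializerInstance;

public class DeserializeSketch {
  // Drains a serialized byte[] through the new stream API and returns the record count.
  static int drain(SerializerInstance instance, byte[] bytes) throws Exception {
    SerInputStream in =
        SerInputStream.newInputStream(Unpooled.wrappedBuffer(ByteBuffer.wrap(bytes)));
    DeserializationStream dStream =
        instance.deserializeStream(in, Text.class, IntWritable.class, false, false);
    dStream.init(); // streams must now be initialized before the first nextRecord()
    int n = 0;
    while (dStream.nextRecord()) {
      n++; // getCurrentKey()/getCurrentValue() return each deserialized pair
    }
    return n;
  }
}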
----- file 8 -----
@@ -17,7 +17,6 @@

 package org.apache.tez.runtime.library.input;

-import java.io.ByteArrayOutputStream;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Comparator;
@@ -29,6 +28,7 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
+import io.netty.buffer.ByteBuf;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.io.IntWritable;
@@ -72,12 +72,14 @@
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.merger.Merger;
 import org.apache.uniffle.common.merger.Segment;
+import org.apache.uniffle.common.serializer.DynBufferSerOutputStream;
+import org.apache.uniffle.common.serializer.SerOutputStream;
 import org.apache.uniffle.common.serializer.SerializerUtils;
 import org.apache.uniffle.common.util.BlockIdLayout;

 import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_DESTINATION_VERTEX_ID;
 import static org.apache.tez.common.RssTezConfig.RSS_SHUFFLE_SOURCE_VERTEX_ID;
-import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBytes;
+import static org.apache.uniffle.common.serializer.SerializerUtils.genSortedRecordBuffer;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
@@ -166,10 +168,11 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
         SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 1L, 0, 2, RECORDS_NUM));
     segments.add(
         SerializerUtils.genMemorySegment(rssConf, keyClass, valueClass, 2L, 1, 2, RECORDS_NUM));
-    ByteArrayOutputStream output = new ByteArrayOutputStream();
+    segments.forEach(segment -> segment.init());
+    SerOutputStream output = new DynBufferSerOutputStream();
     Merger.merge(rssConf, output, segments, keyClass, valueClass, comparator, false);
     output.close();
-    ByteBuffer[][] buffers = new ByteBuffer[][] {{ByteBuffer.wrap(output.toByteArray())}};
+    ByteBuf[][] buffers = new ByteBuf[][] {{output.toByteBuf()}};
     ShuffleServerClient serverClient =
         new MockedShuffleServerClient(new int[] {PARTITION_ID}, buffers, blockIds);
     RMRecordsReader recordsReader =
@@ -362,15 +365,12 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
             false,
             null);
     RMRecordsReader recordsReaderSpy = spy(recordsReader);
-    ByteBuffer[][] buffers = new ByteBuffer[3][2];
+    ByteBuf[][] buffers = new ByteBuf[3][2];
     for (int i = 0; i < 3; i++) {
-      buffers[i][0] =
-          ByteBuffer.wrap(
-              genSortedRecordBytes(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1));
+      buffers[i][0] = genSortedRecordBuffer(rssConf, keyClass, valueClass, i, 3, RECORDS_NUM, 1);
       buffers[i][1] =
-          ByteBuffer.wrap(
-              genSortedRecordBytes(
-                  rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1));
+          genSortedRecordBuffer(
+              rssConf, keyClass, valueClass, i + RECORDS_NUM * 3, 3, RECORDS_NUM, 1);
     }
     ShuffleServerClient serverClient =
         new MockedShuffleServerClient(