diff --git a/common/src/main/java/org/apache/uniffle/common/merger/KeyValueIterator.java b/common/src/main/java/org/apache/uniffle/common/merger/KeyValueIterator.java new file mode 100644 index 0000000000..22344f8039 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/merger/KeyValueIterator.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.merger; + +import java.io.IOException; + +public interface KeyValueIterator { + + K getCurrentKey(); + + V getCurrentValue(); + + boolean next() throws IOException; + + void close() throws IOException; +} diff --git a/common/src/main/java/org/apache/uniffle/common/merger/MergeState.java b/common/src/main/java/org/apache/uniffle/common/merger/MergeState.java new file mode 100644 index 0000000000..3439d07010 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/merger/MergeState.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.merger; + +public enum MergeState { + DONE(0), + INITED(1), + MERGING(2), + INTERNAL_ERROR(3); + + private final int code; + + MergeState(int code) { + this.code = code; + } + + public int code() { + return code; + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/merger/Merger.java b/common/src/main/java/org/apache/uniffle/common/merger/Merger.java new file mode 100644 index 0000000000..7fcbe2d9b4 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/merger/Merger.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.merger; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.function.Function; + +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.io.RawComparator; +import org.apache.hadoop.util.PriorityQueue; + +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.records.RecordsWriter; + +public class Merger { + + public static class MergeQueue extends PriorityQueue implements KeyValueIterator { + + private final RssConf rssConf; + private final List segments; + private final Class keyClass; + private final Class valueClass; + private Comparator comparator; + private boolean raw; + + private Object currentKey; + private Object currentValue; + private Segment minSegment; + private Function popSegmentHook; + + public MergeQueue( + RssConf rssConf, + List segments, + Class keyClass, + Class valueClass, + Comparator comparator, + boolean raw) { + this.rssConf = rssConf; + this.segments = segments; + this.keyClass = keyClass; + this.valueClass = valueClass; + if (comparator == null) { + throw new RssException("comparator is null!"); + } + this.raw = raw; + this.comparator = comparator; + } + + public void setPopSegmentHook(Function popSegmentHook) { + this.popSegmentHook = popSegmentHook; + } + + @Override + protected boolean lessThan(Object o1, Object o2) { + if (raw) { + Segment s1 = (Segment) o1; + Segment s2 = (Segment) o2; + DataOutputBuffer key1 = (DataOutputBuffer) s1.getCurrentKey(); + DataOutputBuffer key2 = (DataOutputBuffer) s2.getCurrentKey(); + int c = + ((RawComparator) comparator) + .compare(key1.getData(), 0, key1.getLength(), key2.getData(), 0, key2.getLength()); + return c < 0 || ((c == 0) && s1.getId() < s2.getId()); + } else { + Segment s1 = (Segment) o1; + Segment s2 = (Segment) o2; + Object key1 = s1.getCurrentKey(); + Object key2 = s2.getCurrentKey(); + int c = comparator.compare(key1, key2); + return c < 0 || ((c == 0) && s1.getId() < s2.getId()); + } + } + + public void init() throws IOException { + List segmentsToMerge = new ArrayList(); + for (Segment segment : segments) { + boolean hasNext = segment.next(); + if (hasNext) { + segmentsToMerge.add(segment); + } else { + segment.close(); + } + } + initialize(segmentsToMerge.size()); + clear(); + for (Segment segment : segmentsToMerge) { + put(segment); + } + } + + @Override + public Object getCurrentKey() { + return currentKey; + } + + @Override + public Object getCurrentValue() { + return currentValue; + } + + @Override + public boolean next() throws IOException { + if (size() == 0) { + resetKeyValue(); + return false; + } + + if (minSegment != null) { + adjustPriorityQueue(minSegment); + if (size() == 0) { + minSegment = null; + resetKeyValue(); + return false; + } + } + minSegment = top(); + currentKey = minSegment.getCurrentKey(); + currentValue = minSegment.getCurrentValue(); + return true; + } + + private void resetKeyValue() { + currentKey = null; + currentValue = null; + } + + private void adjustPriorityQueue(Segment segment) throws IOException { + if (segment.next()) { + adjustTop(); + } else { + pop(); + segment.close(); + if (popSegmentHook != null) { + Segment newSegment = popSegmentHook.apply((int) segment.getId()); + if (newSegment != null) { + if (newSegment.next()) { + put(newSegment); + } else { + newSegment.close(); + } + } + } + } + } + + void merge(OutputStream output) throws IOException { + RecordsWriter writer = + new RecordsWriter(rssConf, output, keyClass, valueClass, raw); + boolean recorded = true; + while (this.next()) { + writer.append(this.getCurrentKey(), this.getCurrentValue()); + if (output instanceof Recordable) { + recorded = + ((Recordable) output) + .record(writer.getTotalBytesWritten(), () -> writer.flush(), false); + } + } + writer.flush(); + if (!recorded) { + ((Recordable) output).record(writer.getTotalBytesWritten(), null, true); + } + writer.close(); + } + + @Override + public void close() throws IOException { + Segment segment; + while ((segment = pop()) != null) { + segment.close(); + } + } + } + + public static void merge( + RssConf conf, + OutputStream output, + List segments, + Class keyClass, + Class valueClass, + Comparator comparator, + boolean raw) + throws IOException { + MergeQueue mergeQueue = new MergeQueue(conf, segments, keyClass, valueClass, comparator, raw); + mergeQueue.init(); + mergeQueue.merge(output); + mergeQueue.close(); + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/merger/Recordable.java b/common/src/main/java/org/apache/uniffle/common/merger/Recordable.java new file mode 100644 index 0000000000..79604b7f61 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/merger/Recordable.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.merger; + +import java.io.IOException; + +public interface Recordable { + + @FunctionalInterface + interface Flushable { + void flush() throws IOException; + } + + boolean record(long written, Flushable flush, boolean force) throws IOException; +} diff --git a/common/src/main/java/org/apache/uniffle/common/merger/Segment.java b/common/src/main/java/org/apache/uniffle/common/merger/Segment.java new file mode 100644 index 0000000000..f8a7301229 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/merger/Segment.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.merger; + +import java.io.IOException; + +public abstract class Segment { + + private long id; + + public Segment(long id) { + this.id = id; + } + + public abstract boolean next() throws IOException; + + public abstract Object getCurrentKey(); + + public abstract Object getCurrentValue(); + + public long getId() { + return this.id; + } + + public abstract void close() throws IOException; +} diff --git a/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java b/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java new file mode 100644 index 0000000000..1966a3e511 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/merger/StreamedSegment.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.merger; + +import java.io.File; +import java.io.IOException; + +import io.netty.buffer.ByteBuf; + +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.records.RecordsReader; +import org.apache.uniffle.common.serializer.PartialInputStream; +import org.apache.uniffle.common.serializer.PartialInputStreamImpl; + +public class StreamedSegment extends Segment { + + private RecordsReader reader; + ByteBuf byteBuf = null; + + public StreamedSegment( + RssConf rssConf, + PartialInputStream inputStream, + long blockId, + Class keyClass, + Class valueClass, + boolean raw) { + super(blockId); + this.reader = new RecordsReader<>(rssConf, inputStream, keyClass, valueClass, raw); + } + + public StreamedSegment( + RssConf rssConf, ByteBuf byteBuf, long blockId, Class keyClass, Class valueClass, boolean raw) + throws IOException { + super(blockId); + this.byteBuf = byteBuf; + this.byteBuf.retain(); + byte[] buffer = byteBuf.array(); + this.reader = + new RecordsReader<>( + rssConf, + PartialInputStreamImpl.newInputStream(buffer, 0, buffer.length), + keyClass, + valueClass, + raw); + } + + // The buffer must be sorted by key + public StreamedSegment( + RssConf rssConf, byte[] buffer, long blockId, Class keyClass, Class valueClass, boolean raw) + throws IOException { + super(blockId); + this.reader = + new RecordsReader<>( + rssConf, + PartialInputStreamImpl.newInputStream(buffer, 0, buffer.length), + keyClass, + valueClass, + raw); + } + + public StreamedSegment( + RssConf rssConf, + File file, + long start, + long end, + long blockId, + Class keyClass, + Class valueClass, + boolean raw) + throws IOException { + super(blockId); + this.reader = + new RecordsReader( + rssConf, + PartialInputStreamImpl.newInputStream(file, start, end), + keyClass, + valueClass, + raw); + } + + @Override + public boolean next() throws IOException { + return this.reader.next(); + } + + @Override + public Object getCurrentKey() { + return this.reader.getCurrentKey(); + } + + @Override + public Object getCurrentValue() { + return this.reader.getCurrentValue(); + } + + @Override + public void close() throws IOException { + if (byteBuf != null) { + this.byteBuf.release(); + this.byteBuf = null; + } + if (this.reader != null) { + this.reader.close(); + this.reader = null; + } + } +} diff --git a/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java b/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java new file mode 100644 index 0000000000..1757ad0047 --- /dev/null +++ b/common/src/test/java/org/apache/uniffle/common/merger/MergerTest.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.merger; + +import java.io.File; +import java.io.FileOutputStream; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +import org.apache.hadoop.io.RawComparator; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.records.RecordsReader; +import org.apache.uniffle.common.serializer.PartialInputStreamImpl; +import org.apache.uniffle.common.serializer.SerializerUtils; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MergerTest { + + private static final int RECORDS = 1009; + private static final int SEGMENTS = 4; + + @ParameterizedTest + @ValueSource( + strings = { + "org.apache.hadoop.io.Text,org.apache.hadoop.io.IntWritable", + }) + public void testMergeSegmentToFile(String classes, @TempDir File tmpDir) throws Exception { + // 1 Parse arguments + String[] classArray = classes.split(","); + Class keyClass = SerializerUtils.getClassByName(classArray[0]); + Class valueClass = SerializerUtils.getClassByName(classArray[1]); + + // 2 Construct segments, then merge + RssConf rssConf = new RssConf(); + List segments = new ArrayList<>(); + Comparator comparator = SerializerUtils.getComparator(keyClass); + for (int i = 0; i < SEGMENTS; i++) { + if (i % 2 == 0) { + segments.add( + SerializerUtils.genMemorySegment( + rssConf, + keyClass, + valueClass, + i, + i, + SEGMENTS, + RECORDS, + comparator instanceof RawComparator)); + } else { + segments.add( + SerializerUtils.genFileSegment( + rssConf, + keyClass, + valueClass, + i, + i, + SEGMENTS, + RECORDS, + tmpDir, + comparator instanceof RawComparator)); + } + } + File mergedFile = new File(tmpDir, "data.merged"); + FileOutputStream outputStream = new FileOutputStream(mergedFile); + Merger.merge( + rssConf, + outputStream, + segments, + keyClass, + valueClass, + comparator, + comparator instanceof RawComparator); + outputStream.close(); + + // 3 Check the merged file + RecordsReader reader = + new RecordsReader( + rssConf, + PartialInputStreamImpl.newInputStream(mergedFile, 0, mergedFile.length()), + keyClass, + valueClass, + false); + int index = 0; + while (reader.next()) { + assertEquals(SerializerUtils.genData(keyClass, index), reader.getCurrentKey()); + assertEquals(SerializerUtils.genData(valueClass, index), reader.getCurrentValue()); + index++; + } + assertEquals(RECORDS * SEGMENTS, index); + reader.close(); + } +} diff --git a/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java b/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java index b57114fb1f..d5675182bf 100644 --- a/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java +++ b/common/src/test/java/org/apache/uniffle/common/serializer/SerializerUtils.java @@ -17,12 +17,22 @@ package org.apache.uniffle.common.serializer; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; import java.util.Comparator; import com.google.common.base.Objects; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; +import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.merger.Segment; +import org.apache.uniffle.common.merger.StreamedSegment; +import org.apache.uniffle.common.records.RecordsWriter; + public class SerializerUtils { public static class SomeClass { @@ -125,4 +135,97 @@ public int compare(Integer o1, Integer o2) { } return null; } + + public static byte[] genSortedRecordBytes( + RssConf rssConf, + Class keyClass, + Class valueClass, + int start, + int interval, + int length, + int replica) + throws IOException { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + genSortedRecord(rssConf, keyClass, valueClass, start, interval, length, output, replica); + return output.toByteArray(); + } + + public static Segment genMemorySegment( + RssConf rssConf, + Class keyClass, + Class valueClass, + long blockId, + int start, + int interval, + int length) + throws IOException { + return genMemorySegment(rssConf, keyClass, valueClass, blockId, start, interval, length, false); + } + + public static Segment genMemorySegment( + RssConf rssConf, + Class keyClass, + Class valueClass, + long blockId, + int start, + int interval, + int length, + boolean raw) + throws IOException { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + genSortedRecord(rssConf, keyClass, valueClass, start, interval, length, output, 1); + return new StreamedSegment(rssConf, output.toByteArray(), blockId, keyClass, valueClass, raw); + } + + public static Segment genFileSegment( + RssConf rssConf, + Class keyClass, + Class valueClass, + long blockId, + int start, + int interval, + int length, + File tmpDir) + throws IOException { + return genFileSegment( + rssConf, keyClass, valueClass, blockId, start, interval, length, tmpDir, false); + } + + public static Segment genFileSegment( + RssConf rssConf, + Class keyClass, + Class valueClass, + long blockId, + int start, + int interval, + int length, + File tmpDir, + boolean raw) + throws IOException { + File file = new File(tmpDir, "data." + start); + genSortedRecord( + rssConf, keyClass, valueClass, start, interval, length, new FileOutputStream(file), 1); + return new StreamedSegment(rssConf, file, 0, file.length(), blockId, keyClass, valueClass, raw); + } + + private static void genSortedRecord( + RssConf rssConf, + Class keyClass, + Class valueClass, + int start, + int interval, + int length, + OutputStream output, + int replica) + throws IOException { + RecordsWriter writer = new RecordsWriter(rssConf, output, keyClass, valueClass, false); + for (int i = 0; i < length; i++) { + for (int j = 0; j < replica; j++) { + writer.append( + SerializerUtils.genData(keyClass, start + i * interval), + SerializerUtils.genData(valueClass, start + i * interval)); + } + } + writer.close(); + } }