Skip to content

Commit

Permalink
Merge pull request #31579 Avoid length-prefix-bytes substitutions for…
Browse files Browse the repository at this point in the history
… Flink boundaries.
  • Loading branch information
robertwb authored Jul 9, 2024
2 parents ef143ae + 78bab0d commit dda0fbf
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Map;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
import org.apache.beam.runners.fnexecution.wire.LengthPrefixUnknownCoders;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
Expand Down Expand Up @@ -80,4 +81,19 @@ static <T> Coder<T> lookupCoder(RunnerApi.Pipeline p, String pCollectionId) {
throw new RuntimeException(exn);
}
}

static void registerKnownCoderFor(RunnerApi.Pipeline p, String pCollectionId) {
registerAsKnownCoder(p, p.getComponents().getPcollectionsOrThrow(pCollectionId).getCoderId());
}

static void registerAsKnownCoder(RunnerApi.Pipeline p, String coderId) {
RunnerApi.Coder coder = p.getComponents().getCodersOrThrow(coderId);
// It'd be more targeted to note the coder id rather than the URN,
// but the length prefixing code is invoked within a deeply nested
// sequence of static method calls.
LengthPrefixUnknownCoders.addKnownCoderUrn(coder.getSpec().getUrn());
for (String componentCoderId : coder.getComponentCoderIdsList()) {
registerAsKnownCoder(p, componentCoderId);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.sdk.util.construction.Environments;
import org.apache.beam.sdk.util.construction.PTransformTranslation;
import org.apache.beam.sdk.util.construction.PipelineTranslation;
import org.apache.beam.sdk.util.construction.SdkComponents;
import org.apache.beam.sdk.values.PBegin;
Expand All @@ -37,6 +38,7 @@
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
import org.apache.flink.api.common.typeinfo.TypeInformation;

Expand Down Expand Up @@ -108,6 +110,26 @@ Map<String, DataSetOrStreamT> applyBeamPTransformInternal(
// Extract the pipeline definition so that we can apply or Flink translation logic.
SdkComponents components = SdkComponents.create(pipelineOptions);
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline, components);

// Avoid swapping input and output coders for BytesCoders.
// As we have instantiated the actual coder objects here, there is no need ot length prefix them
// anyway.
// TODO(robertwb): Even better would be to avoid coding and decoding along these edges via a
// direct
// in-memory channel for embedded mode. As well as improving performance, there could be
// control-flow advantages too.
for (RunnerApi.PTransform transformProto :
pipelineProto.getComponents().getTransforms().values()) {
if (FlinkInput.URN.equals(PTransformTranslation.urnForTransformOrNull(transformProto))) {
BeamAdapterCoderUtils.registerKnownCoderFor(
pipelineProto, Iterables.getOnlyElement(transformProto.getOutputs().values()));
} else if (FlinkOutput.URN.equals(
PTransformTranslation.urnForTransformOrNull(transformProto))) {
BeamAdapterCoderUtils.registerKnownCoderFor(
pipelineProto, Iterables.getOnlyElement(transformProto.getInputs().values()));
}
}

return translator.translate(inputs, pipelineProto, executionEnvironment);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ private <InputT> FlinkBatchPortablePipelineTranslator.PTransformTranslator flink
Coder<InputT> outputCoder =
BeamAdapterCoderUtils.lookupCoder(
p, Iterables.getOnlyElement(t.getTransform().getInputsMap().values()));
// TODO(robertwb): Also handle or disable length prefix coding (for embedded mode at least).
outputMap.put(
outputId,
new MapOperator<WindowedValue<InputT>, InputT>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ public void processElement(
Coder<InputT> outputCoder =
BeamAdapterCoderUtils.lookupCoder(
p, Iterables.getOnlyElement(transform.getInputsMap().values()));
// TODO(robertwb): Also handle or disable length prefix coding (for embedded mode at least).
outputMap.put(
outputId,
inputDataStream.transform(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsInAnyOrder;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
import java.util.Map;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
Expand Down Expand Up @@ -132,4 +138,48 @@ public void testApplyGroupingTransform() throws Exception {

assertThat(result.collect(), containsInAnyOrder(KV.of("a", 2L), KV.of("b", 1L)));
}

@Test
public void testCustomCoder() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.createCollectionsEnvironment();

DataSet<String> input = env.fromCollection(ImmutableList.of("a", "b", "c"));
DataSet<String> result =
new BeamFlinkDataSetAdapter()
.applyBeamPTransform(
input,
new PTransform<PCollection<String>, PCollection<String>>() {
@Override
public PCollection<String> expand(PCollection<String> input) {
return input.apply(withPrefix("x")).setCoder(new MyCoder());
}
});

assertThat(result.collect(), containsInAnyOrder("xa", "xb", "xc"));
}

private static class MyCoder extends Coder<String> {

private static final int CUSTOM_MARKER = 3;

@Override
public void encode(String value, OutputStream outStream) throws IOException {
outStream.write(CUSTOM_MARKER);
StringUtf8Coder.of().encode(value, outStream);
}

@Override
public String decode(InputStream inStream) throws IOException {
assert inStream.read() == CUSTOM_MARKER;
return StringUtf8Coder.of().decode(inStream);
}

@Override
public List<? extends Coder<?>> getCoderArguments() {
return null;
}

@Override
public void verifyDeterministic() throws NonDeterministicException {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
*/
package org.apache.beam.runners.fnexecution.wire;

import java.util.HashSet;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Predicate;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
Expand All @@ -28,6 +30,17 @@

/** Utilities for replacing or wrapping unknown coders with {@link LengthPrefixCoder}. */
public class LengthPrefixUnknownCoders {
private static Set<String> otherKnownCoderUrns = new HashSet<>();

/**
* Registers a coder as being of known type and as such not meriting length prefixing.
*
* @param urn The urn of the coder that should not be length prefixed.
*/
public static void addKnownCoderUrn(String urn) {
otherKnownCoderUrns.add(urn);
}

/**
* Recursively traverses the coder tree and wraps the first unknown coder in every branch with a
* {@link LengthPrefixCoder} unless an ancestor coder is itself a {@link LengthPrefixCoder}. If
Expand Down Expand Up @@ -59,7 +72,7 @@ public static String addLengthPrefixedCoder(
// with a length prefix coder or replace it with a length prefix byte array coder.
if (ModelCoders.LENGTH_PREFIX_CODER_URN.equals(urn)) {
return replaceWithByteArrayCoder ? lengthPrefixedByteArrayCoderId : coderId;
} else if (ModelCoders.urns().contains(urn)) {
} else if (ModelCoders.urns().contains(urn) || otherKnownCoderUrns.contains(urn)) {
return addForModelCoder(coderId, components, replaceWithByteArrayCoder);
} else {
return replaceWithByteArrayCoder
Expand All @@ -71,6 +84,9 @@ public static String addLengthPrefixedCoder(
private static String addForModelCoder(
String coderId, RunnerApi.Components.Builder components, boolean replaceWithByteArrayCoder) {
Coder coder = components.getCodersOrThrow(coderId);
if (coder.getComponentCoderIdsCount() == 0) {
return coderId;
}
RunnerApi.Coder.Builder builder = coder.toBuilder().clearComponentCoderIds();
for (String componentCoderId : coder.getComponentCoderIdsList()) {
builder.addComponentCoderIds(
Expand Down

0 comments on commit dda0fbf

Please sign in to comment.