Port DirectRunner WriteFiles override to SDK-agnostic APIs
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/ed6bd18b Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/ed6bd18b Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/ed6bd18b Branch: refs/heads/master Commit: ed6bd18bffe8a51d5fc2a59ff9aaa731b196d58a Parents: 02dbaef Author: Kenneth Knowles <k...@google.com> Authored: Fri May 26 16:07:45 2017 -0700 Committer: Kenneth Knowles <k...@google.com> Committed: Fri Jun 9 19:56:52 2017 -0700 ---------------------------------------------------------------------- .../core/construction/PTransformMatchers.java | 17 ++++++++--- .../direct/WriteWithShardingFactory.java | 30 ++++++++++++++------ .../direct/WriteWithShardingFactoryTest.java | 26 +++++++++++------ 3 files changed, 52 insertions(+), 21 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/ed6bd18b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java index c339891..0d27241 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java @@ -17,13 +17,14 @@ */ package org.apache.beam.runners.core.construction; +import static org.apache.beam.runners.core.construction.PTransformTranslation.WRITE_FILES_TRANSFORM_URN; + import com.google.common.base.MoreObjects; import java.io.IOException; import java.util.HashSet; import java.util.Set; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; -import org.apache.beam.sdk.io.WriteFiles; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformMatcher; import org.apache.beam.sdk.transforms.DoFn; @@ -359,10 +360,18 @@ public class PTransformMatchers { return new PTransformMatcher() { @Override public boolean matches(AppliedPTransform<?, ?, ?> application) { - if (PTransformTranslation.WRITE_FILES_TRANSFORM_URN.equals( + if (WRITE_FILES_TRANSFORM_URN.equals( PTransformTranslation.urnForTransformOrNull(application.getTransform()))) { - WriteFiles write = (WriteFiles) application.getTransform(); - return write.getSharding() == null && write.getNumShards() == null; + try { + return WriteFilesTranslation.isRunnerDeterminedSharding( + (AppliedPTransform) application); + } catch (IOException exc) { + throw new RuntimeException( + String.format( + "Transform with URN %s failed to parse: %s", + WRITE_FILES_TRANSFORM_URN, application.getTransform()), + exc); + } } return false; } http://git-wip-us.apache.org/repos/asf/beam/blob/ed6bd18b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java index 65a5a19..d8734a1 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java @@ -21,11 +21,13 @@ package org.apache.beam.runners.direct; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; import com.google.common.base.Suppliers; +import java.io.IOException; import java.io.Serializable; import java.util.Collections; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import org.apache.beam.runners.core.construction.PTransformReplacements; +import org.apache.beam.runners.core.construction.WriteFilesTranslation; import org.apache.beam.sdk.io.WriteFiles; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverrideFactory; @@ -43,23 +45,33 @@ import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; /** - * A {@link PTransformOverrideFactory} that overrides {@link WriteFiles} - * {@link PTransform PTransforms} with an unspecified number of shards with a write with a - * specified number of shards. The number of shards is the log base 10 of the number of input - * records, with up to 2 additional shards. + * A {@link PTransformOverrideFactory} that overrides {@link WriteFiles} {@link PTransform + * PTransforms} with an unspecified number of shards with a write with a specified number of shards. + * The number of shards is the log base 10 of the number of input records, with up to 2 additional + * shards. */ class WriteWithShardingFactory<InputT> - implements PTransformOverrideFactory<PCollection<InputT>, PDone, WriteFiles<InputT>> { + implements PTransformOverrideFactory< + PCollection<InputT>, PDone, PTransform<PCollection<InputT>, PDone>> { static final int MAX_RANDOM_EXTRA_SHARDS = 3; @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3; @Override public PTransformReplacement<PCollection<InputT>, PDone> getReplacementTransform( - AppliedPTransform<PCollection<InputT>, PDone, WriteFiles<InputT>> transform) { + AppliedPTransform<PCollection<InputT>, PDone, PTransform<PCollection<InputT>, PDone>> + transform) { - return PTransformReplacement.of( - PTransformReplacements.getSingletonMainInput(transform), - transform.getTransform().withSharding(new LogElementShardsWithDrift<InputT>())); + try { + WriteFiles<InputT> replacement = WriteFiles.to(WriteFilesTranslation.getSink(transform)); + if (WriteFilesTranslation.isWindowedWrites(transform)) { + replacement = replacement.withWindowedWrites(); + } + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + replacement.withSharding(new LogElementShardsWithDrift<InputT>())); + } catch (IOException e) { + throw new RuntimeException(e); + } } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/ed6bd18b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java index a88d95e..41d671f 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java @@ -30,6 +30,7 @@ import static org.junit.Assert.assertThat; import java.io.File; import java.io.FileReader; import java.io.Reader; +import java.io.Serializable; import java.nio.CharBuffer; import java.util.ArrayList; import java.util.Collections; @@ -53,6 +54,7 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnTester; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -71,11 +73,17 @@ import org.junit.runners.JUnit4; * Tests for {@link WriteWithShardingFactory}. */ @RunWith(JUnit4.class) -public class WriteWithShardingFactoryTest { +public class WriteWithShardingFactoryTest implements Serializable { + private static final int INPUT_SIZE = 10000; - @Rule public TemporaryFolder tmp = new TemporaryFolder(); - private WriteWithShardingFactory<Object> factory = new WriteWithShardingFactory<>(); - @Rule public final TestPipeline p = TestPipeline.create().enableAbandonedNodeEnforcement(false); + + @Rule public transient TemporaryFolder tmp = new TemporaryFolder(); + + private transient WriteWithShardingFactory<Object> factory = new WriteWithShardingFactory<>(); + + @Rule + public final transient TestPipeline p = + TestPipeline.create().enableAbandonedNodeEnforcement(false); @Test public void dynamicallyReshardedWrite() throws Exception { @@ -135,7 +143,8 @@ public class WriteWithShardingFactoryTest { DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE, "", false); - WriteFiles<Object> original = + + PTransform<PCollection<Object>, PDone> original = WriteFiles.to( new FileBasedSink<Object>(StaticValueProvider.of(outputDirectory), policy) { @Override @@ -146,9 +155,10 @@ public class WriteWithShardingFactoryTest { @SuppressWarnings("unchecked") PCollection<Object> objs = (PCollection) p.apply(Create.empty(VoidCoder.of())); - AppliedPTransform<PCollection<Object>, PDone, WriteFiles<Object>> originalApplication = - AppliedPTransform.of( - "write", objs.expand(), Collections.<TupleTag<?>, PValue>emptyMap(), original, p); + AppliedPTransform<PCollection<Object>, PDone, PTransform<PCollection<Object>, PDone>> + originalApplication = + AppliedPTransform.of( + "write", objs.expand(), Collections.<TupleTag<?>, PValue>emptyMap(), original, p); assertThat( factory.getReplacementTransform(originalApplication).getTransform(),