This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a309c2  [BEAM-10869] Make WriteToPubsub output serialized PubsubMessage proto bytes when using runner v2
     new c9e2580  Merge pull request #12806 from [BEAM-10869] Make WriteToPubsub output serialized PubsubMessage proto bytes when using runner v2
7a309c2 is described below

commit 7a309c2ce602fba67351f36067f9c65b150b86a8
Author: Boyuan Zhang <boyu...@google.com>
AuthorDate: Wed Sep 9 21:05:02 2020 -0700

    [BEAM-10869] Make WriteToPubsub output serialized PubsubMessage proto bytes when using runner v2
---
 .../beam/runners/dataflow/DataflowRunner.java      | 150 ++++++++++++++++++---
 .../beam/sdk/io/gcp/pubsub/ExternalRead.java       |  20 +--
 .../beam/sdk/io/gcp/pubsub/ExternalWrite.java      |  19 +--
 .../beam/sdk/io/gcp/pubsub/PubsubMessages.java     |  58 ++++++++
 sdks/python/apache_beam/io/gcp/pubsub.py           |  46 ++++---
 sdks/python/apache_beam/io/gcp/pubsub_test.py      |   5 +-
 .../runners/dataflow/dataflow_runner.py            |   8 +-
 .../apache_beam/runners/direct/direct_runner.py    |  19 ++-
 8 files changed, 240 insertions(+), 85 deletions(-)

diff --git 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 8931143..1e0415d 100644
--- 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -79,6 +79,7 @@ import 
org.apache.beam.runners.core.construction.UnconsumedReads;
 import org.apache.beam.runners.core.construction.WriteFilesTranslation;
 import 
org.apache.beam.runners.dataflow.DataflowPipelineTranslator.JobSpecification;
 import 
org.apache.beam.runners.dataflow.StreamingViewOverrides.StreamingCreatePCollectionViewFactory;
+import 
org.apache.beam.runners.dataflow.TransformTranslator.StepTranslationContext;
 import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions;
 import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
 import 
org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions;
@@ -92,6 +93,7 @@ import org.apache.beam.sdk.Pipeline.PipelineVisitor;
 import org.apache.beam.sdk.PipelineResult.State;
 import org.apache.beam.sdk.PipelineRunner;
 import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.Coder.NonDeterministicException;
 import org.apache.beam.sdk.coders.KvCoder;
@@ -109,6 +111,7 @@ import org.apache.beam.sdk.io.fs.ResourceId;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage;
 import 
org.apache.beam.sdk.io.gcp.pubsub.PubsubMessageWithAttributesAndMessageIdCoder;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessageWithAttributesCoder;
+import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessages;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubUnboundedSink;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubUnboundedSource;
 import org.apache.beam.sdk.options.PipelineOptions;
@@ -129,6 +132,7 @@ import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupIntoBatches;
 import org.apache.beam.sdk.transforms.Impulse;
+import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Reshuffle;
@@ -155,6 +159,7 @@ import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.sdk.values.TypeDescriptors;
 import org.apache.beam.sdk.values.ValueWithRecordId;
 import org.apache.beam.sdk.values.WindowingStrategy;
@@ -476,10 +481,18 @@ public class DataflowRunner extends 
PipelineRunner<DataflowPipelineJob> {
                 new StreamingPubsubIOReadOverrideFactory()));
       }
       if (!hasExperiment(options, "enable_custom_pubsub_sink")) {
-        overridesBuilder.add(
-            PTransformOverride.of(
-                PTransformMatchers.classEqualTo(PubsubUnboundedSink.class),
-                new StreamingPubsubIOWriteOverrideFactory(this)));
+        if (hasExperiment(options, "use_runner_v2")
+            || hasExperiment(options, "use_unified_worker")) {
+          overridesBuilder.add(
+              PTransformOverride.of(
+                  PTransformMatchers.classEqualTo(PubsubUnboundedSink.class),
+                  new DataflowWriteToPubsubRunnerV2OverrideFactory()));
+        } else {
+          overridesBuilder.add(
+              PTransformOverride.of(
+                  PTransformMatchers.classEqualTo(PubsubUnboundedSink.class),
+                  new StreamingPubsubIOWriteOverrideFactory(this)));
+        }
       }
       overridesBuilder.add(
           PTransformOverride.of(
@@ -1441,6 +1454,39 @@ public class DataflowRunner extends 
PipelineRunner<DataflowPipelineJob> {
 
   /**
    * Suppress application of {@link PubsubUnboundedSink#expand} in streaming 
mode so that we can
+   * instead defer to Windmill's implementation when using Dataflow runner v2.
+   */
+  private static class DataflowRunnerV2PubsubSink extends 
PTransform<PCollection<byte[]>, PDone> {
+
+    private final PubsubUnboundedSink transform;
+
+    public DataflowRunnerV2PubsubSink(PubsubUnboundedSink transform) {
+      this.transform = transform;
+    }
+
+    PubsubUnboundedSink getOverriddenTransform() {
+      return transform;
+    }
+
+    @Override
+    public PDone expand(PCollection<byte[]> input) {
+      return PDone.in(input.getPipeline());
+    }
+
+    @Override
+    protected String getKindString() {
+      return "DataflowRunnerV2PubsubSink";
+    }
+
+    static {
+      DataflowPipelineTranslator.registerTransformTranslator(
+          DataflowRunnerV2PubsubSink.class,
+          new 
StreamingPubsubSinkTranslators.DataflowRunnerV2PubsubSinkTranslator());
+    }
+  }
+
+  /**
+   * Suppress application of {@link PubsubUnboundedSink#expand} in streaming 
mode so that we can
    * instead defer to Windmill's implementation.
    */
   private static class StreamingPubsubIOWrite
@@ -1469,21 +1515,45 @@ public class DataflowRunner extends 
PipelineRunner<DataflowPipelineJob> {
 
     static {
       DataflowPipelineTranslator.registerTransformTranslator(
-          StreamingPubsubIOWrite.class, new 
StreamingPubsubIOWriteTranslator());
+          StreamingPubsubIOWrite.class,
+          new 
StreamingPubsubSinkTranslators.StreamingPubsubIOWriteTranslator());
     }
   }
 
-  /** Rewrite {@link StreamingPubsubIOWrite} to the appropriate internal node. 
*/
-  private static class StreamingPubsubIOWriteTranslator
-      implements TransformTranslator<StreamingPubsubIOWrite> {
+  private static class StreamingPubsubSinkTranslators {
+    /** Rewrite {@link StreamingPubsubIOWrite} to the appropriate internal 
node. */
+    static class StreamingPubsubIOWriteTranslator
+        implements TransformTranslator<StreamingPubsubIOWrite> {
 
-    @Override
-    public void translate(StreamingPubsubIOWrite transform, TranslationContext 
context) {
-      checkArgument(
-          context.getPipelineOptions().isStreaming(),
-          "StreamingPubsubIOWrite is only for streaming pipelines.");
-      PubsubUnboundedSink overriddenTransform = 
transform.getOverriddenTransform();
-      StepTranslationContext stepContext = context.addStep(transform, 
"ParallelWrite");
+      @Override
+      public void translate(StreamingPubsubIOWrite transform, 
TranslationContext context) {
+        checkArgument(
+            context.getPipelineOptions().isStreaming(),
+            "StreamingPubsubIOWrite is only for streaming pipelines.");
+        StepTranslationContext stepContext = context.addStep(transform, 
"ParallelWrite");
+        StreamingPubsubSinkTranslators.translate(
+            transform.getOverriddenTransform(), stepContext, 
context.getInput(transform));
+      }
+    }
+
+    /** Rewrite {@link DataflowRunnerV2PubsubSink} to the appropriate internal 
node. */
+    static class DataflowRunnerV2PubsubSinkTranslator
+        implements TransformTranslator<DataflowRunnerV2PubsubSink> {
+      @Override
+      public void translate(DataflowRunnerV2PubsubSink transform, 
TranslationContext context) {
+        checkArgument(
+            context.getPipelineOptions().isStreaming(),
+            "StreamingPubsubIOWrite is only for streaming pipelines.");
+        StepTranslationContext stepContext = context.addStep(transform, 
"ParallelWrite");
+        StreamingPubsubSinkTranslators.translate(
+            transform.getOverriddenTransform(), stepContext, 
context.getInput(transform));
+      }
+    }
+
+    private static void translate(
+        PubsubUnboundedSink overriddenTransform,
+        StepTranslationContext stepContext,
+        PCollection input) {
       stepContext.addInput(PropertyNames.FORMAT, "pubsub");
       if (overriddenTransform.getTopicProvider().isAccessible()) {
         stepContext.addInput(
@@ -1508,7 +1578,7 @@ public class DataflowRunner extends 
PipelineRunner<DataflowPipelineJob> {
       // Using a GlobalWindowCoder as a place holder because GlobalWindowCoder 
is known coder.
       stepContext.addEncodingInput(
           WindowedValue.getFullCoder(VoidCoder.of(), 
GlobalWindow.Coder.INSTANCE));
-      stepContext.addInput(PropertyNames.PARALLEL_INPUT, 
context.getInput(transform));
+      stepContext.addInput(PropertyNames.PARALLEL_INPUT, input);
     }
   }
 
@@ -1957,6 +2027,54 @@ public class DataflowRunner extends 
PipelineRunner<DataflowPipelineJob> {
     }
   }
 
+  /**
+   * A replacement {@link PTransform} for {@link PubsubUnboundedSink} when 
using dataflow runner v2.
+   */
+  private static class DataflowWriteToPubsubForRunnerV2
+      extends PTransform<PCollection<PubsubMessage>, PDone> {
+
+    private final PubsubUnboundedSink transform;
+
+    public DataflowWriteToPubsubForRunnerV2(PubsubUnboundedSink transform) {
+      this.transform = transform;
+    }
+
+    @Override
+    public PDone expand(PCollection<PubsubMessage> input) {
+      input
+          .apply(
+              "Output Serialized PubsubMessage Proto",
+              MapElements.into(new TypeDescriptor<byte[]>() {})
+                  .via(new PubsubMessages.ParsePayloadAsPubsubMessageProto()))
+          .setCoder(ByteArrayCoder.of())
+          .apply(new DataflowRunnerV2PubsubSink(transform));
+
+      return PDone.in(input.getPipeline());
+    }
+  }
+
+  /**
+   * A {@link PTransformOverrideFactory} to provide replacement {@link 
PTransform} for {@link
+   * PubsubUnboundedSink} when using dataflow runner v2.
+   */
+  private static class DataflowWriteToPubsubRunnerV2OverrideFactory
+      implements PTransformOverrideFactory<PCollection<PubsubMessage>, PDone, 
PubsubUnboundedSink> {
+
+    @Override
+    public PTransformReplacement<PCollection<PubsubMessage>, PDone> 
getReplacementTransform(
+        AppliedPTransform<PCollection<PubsubMessage>, PDone, 
PubsubUnboundedSink> transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new DataflowWriteToPubsubForRunnerV2(transform.getTransform()));
+    }
+
+    @Override
+    public Map<PCollection<?>, ReplacementOutput> mapOutputs(
+        Map<TupleTag<?>, PCollection<?>> outputs, PDone newOutput) {
+      return Collections.emptyMap();
+    }
+  }
+
   @VisibleForTesting
   static class StreamingShardedWriteFactory<UserT, DestinationT, OutputT>
       implements PTransformOverrideFactory<
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/ExternalRead.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/ExternalRead.java
index 93612ae..250b68c 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/ExternalRead.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/ExternalRead.java
@@ -18,18 +18,17 @@
 package org.apache.beam.sdk.io.gcp.pubsub;
 
 import com.google.auto.service.AutoService;
-import com.google.protobuf.ByteString;
 import java.util.Map;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.expansion.ExternalTransformRegistrar;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO.PubsubSubscription;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO.PubsubTopic;
+import 
org.apache.beam.sdk.io.gcp.pubsub.PubsubMessages.ParsePayloadAsPubsubMessageProto;
 import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
 import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
 import org.apache.beam.sdk.transforms.ExternalTransformBuilder;
 import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
@@ -111,21 +110,4 @@ public final class ExternalRead implements 
ExternalTransformRegistrar {
       return readBuilder.build();
     }
   }
-
-  // Convert the PubsubMessage to a PubsubMessage proto, then return its 
serialized representation.
-  private static class ParsePayloadAsPubsubMessageProto
-      implements SerializableFunction<PubsubMessage, byte[]> {
-    @Override
-    public byte[] apply(PubsubMessage input) {
-      Map<String, String> attributes = input.getAttributeMap();
-      com.google.pubsub.v1.PubsubMessage.Builder message =
-          com.google.pubsub.v1.PubsubMessage.newBuilder()
-              .setData(ByteString.copyFrom(input.getPayload()));
-      // TODO(BEAM-8085) this should not be null
-      if (attributes != null) {
-        message.putAllAttributes(attributes);
-      }
-      return message.build().toByteArray();
-    }
-  }
 }
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/ExternalWrite.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/ExternalWrite.java
index 4733e32..1258c48 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/ExternalWrite.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/ExternalWrite.java
@@ -18,16 +18,15 @@
 package org.apache.beam.sdk.io.gcp.pubsub;
 
 import com.google.auto.service.AutoService;
-import com.google.protobuf.InvalidProtocolBufferException;
 import java.util.Map;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.expansion.ExternalTransformRegistrar;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO.PubsubTopic;
+import 
org.apache.beam.sdk.io.gcp.pubsub.PubsubMessages.ParsePubsubMessageProtoAsPayload;
 import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
 import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
 import org.apache.beam.sdk.transforms.ExternalTransformBuilder;
 import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PDone;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
@@ -72,7 +71,8 @@ public final class ExternalWrite implements 
ExternalTransformRegistrar {
 
     @Override
     public PTransform<PCollection<byte[]>, PDone> buildExternal(Configuration 
config) {
-      PubsubIO.Write.Builder<byte[]> writeBuilder = 
PubsubIO.Write.newBuilder(new FormatFn());
+      PubsubIO.Write.Builder<byte[]> writeBuilder =
+          PubsubIO.Write.newBuilder(new ParsePubsubMessageProtoAsPayload());
       if (config.topic != null) {
         StaticValueProvider<String> topic = 
StaticValueProvider.of(config.topic);
         writeBuilder.setTopicProvider(NestedValueProvider.of(topic, 
PubsubTopic::fromPath));
@@ -86,17 +86,4 @@ public final class ExternalWrite implements 
ExternalTransformRegistrar {
       return writeBuilder.build();
     }
   }
-
-  private static class FormatFn implements SerializableFunction<byte[], 
PubsubMessage> {
-    @Override
-    public PubsubMessage apply(byte[] input) {
-      try {
-        com.google.pubsub.v1.PubsubMessage message =
-            com.google.pubsub.v1.PubsubMessage.parseFrom(input);
-        return new PubsubMessage(message.getData().toByteArray(), 
message.getAttributesMap());
-      } catch (InvalidProtocolBufferException e) {
-        throw new RuntimeException("Could not decode Pubsub message", e);
-      }
-    }
-  }
 }
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java
new file mode 100644
index 0000000..6d9e295
--- /dev/null
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubMessages.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.pubsub;
+
+import com.google.protobuf.ByteString;
+import com.google.protobuf.InvalidProtocolBufferException;
+import java.util.Map;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+
+/** Common util functions for converting between PubsubMessage proto and 
{@link PubsubMessage}. */
+public class PubsubMessages {
+  // Convert the PubsubMessage to a PubsubMessage proto, then return its 
serialized representation.
+  public static class ParsePayloadAsPubsubMessageProto
+      implements SerializableFunction<PubsubMessage, byte[]> {
+    @Override
+    public byte[] apply(PubsubMessage input) {
+      Map<String, String> attributes = input.getAttributeMap();
+      com.google.pubsub.v1.PubsubMessage.Builder message =
+          com.google.pubsub.v1.PubsubMessage.newBuilder()
+              .setData(ByteString.copyFrom(input.getPayload()));
+      // TODO(BEAM-8085) this should not be null
+      if (attributes != null) {
+        message.putAllAttributes(attributes);
+      }
+      return message.build().toByteArray();
+    }
+  }
+
+  // Convert the serialized PubsubMessage proto to PubsubMessage.
+  public static class ParsePubsubMessageProtoAsPayload
+      implements SerializableFunction<byte[], PubsubMessage> {
+    @Override
+    public PubsubMessage apply(byte[] input) {
+      try {
+        com.google.pubsub.v1.PubsubMessage message =
+            com.google.pubsub.v1.PubsubMessage.parseFrom(input);
+        return new PubsubMessage(message.getData().toByteArray(), 
message.getAttributesMap());
+      } catch (InvalidProtocolBufferException e) {
+        throw new RuntimeException("Could not decode Pubsub message", e);
+      }
+    }
+  }
+}
diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py 
b/sdks/python/apache_beam/io/gcp/pubsub.py
index b0f8bdf..6c90a0b 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub.py
@@ -246,6 +246,10 @@ class _WriteStringsToPubSub(PTransform):
       topic: Cloud Pub/Sub topic in the form "/topics/<project>/<topic>".
     """
     super(_WriteStringsToPubSub, self).__init__()
+    self.with_attributes = False
+    self.id_label = None
+    self.timestamp_attribute = None
+    self.project, self.topic_name = parse_topic(topic)
     self._sink = _PubSubSink(
         topic, id_label=None, with_attributes=False, timestamp_attribute=None)
 
@@ -286,11 +290,13 @@ class WriteToPubSub(PTransform):
     self.with_attributes = with_attributes
     self.id_label = id_label
     self.timestamp_attribute = timestamp_attribute
+    self.project, self.topic_name = parse_topic(topic)
+    self.full_topic = topic
     self._sink = _PubSubSink(
         topic, id_label, with_attributes, timestamp_attribute)
 
   @staticmethod
-  def to_proto_str(element):
+  def message_to_proto_str(element):
     # type: (PubsubMessage) -> bytes
     if not isinstance(element, PubsubMessage):
       raise TypeError(
@@ -298,13 +304,18 @@ class WriteToPubSub(PTransform):
           'value: %r' % (type(element), element))
     return element._to_proto_str()
 
+  @staticmethod
+  def bytes_to_proto_str(element):
+    # type: (bytes) -> bytes
+    msg = pubsub.types.pubsub_pb2.PubsubMessage()
+    msg.data = element
+    return msg.SerializeToString()
+
   def expand(self, pcoll):
     if self.with_attributes:
-      pcoll = pcoll | 'ToProtobuf' >> Map(self.to_proto_str)
-
-    # Without attributes, message data is written as-is. With attributes,
-    # message data + attributes are passed as a serialized protobuf string (see
-    # ``PubsubMessage._to_proto_str`` for exact protobuf message type).
+      pcoll = pcoll | 'ToProtobuf' >> Map(self.message_to_proto_str)
+    else:
+      pcoll = pcoll | 'ToProtobuf' >> Map(self.bytes_to_proto_str)
     pcoll.element_type = bytes
     return pcoll | Write(self._sink)
 
@@ -313,6 +324,16 @@ class WriteToPubSub(PTransform):
     # TODO(BEAM-3812): Use an actual URN here.
     return self.to_runner_api_pickled(context)
 
+  def display_data(self):
+    return {
+        'topic': DisplayDataItem(self.full_topic, label='Pubsub Topic'),
+        'id_label': DisplayDataItem(self.id_label, label='ID Label Attribute'),
+        'with_attributes': DisplayDataItem(
+            True, label='With Attributes').drop_if_none(),
+        'timestamp_attribute': DisplayDataItem(
+            self.timestamp_attribute, label='Timestamp Attribute'),
+    }
+
 
 PROJECT_ID_REGEXP = '[a-z][-a-z0-9:.]{4,61}[a-z0-9]'
 SUBSCRIPTION_REGEXP = 'projects/([^/]+)/subscriptions/(.+)'
@@ -418,10 +439,11 @@ class _PubSubSink(dataflow_io.NativeSink):
                id_label,  # type: Optional[str]
                with_attributes,  # type: bool
                timestamp_attribute  # type: Optional[str]
-              ):
+               ):
     self.coder = coders.BytesCoder()
     self.full_topic = topic
     self.id_label = id_label
+    #TODO(BEAM-10869): Remove with_attributes since we will never look at it.
     self.with_attributes = with_attributes
     self.timestamp_attribute = timestamp_attribute
 
@@ -432,15 +454,5 @@ class _PubSubSink(dataflow_io.NativeSink):
     """Sink format name required for remote execution."""
     return 'pubsub'
 
-  def display_data(self):
-    return {
-        'topic': DisplayDataItem(self.full_topic, label='Pubsub Topic'),
-        'id_label': DisplayDataItem(self.id_label, label='ID Label Attribute'),
-        'with_attributes': DisplayDataItem(
-            self.with_attributes, label='With Attributes').drop_if_none(),
-        'timestamp_attribute': DisplayDataItem(
-            self.timestamp_attribute, label='Timestamp Attribute'),
-    }
-
   def writer(self):
     raise NotImplementedError
diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py 
b/sdks/python/apache_beam/io/gcp/pubsub_test.py
index 77f1bc9..40fe386 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub_test.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py
@@ -312,16 +312,15 @@ class TestPubSubSource(unittest.TestCase):
 @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
 class TestPubSubSink(unittest.TestCase):
   def test_display_data(self):
-    sink = _PubSubSink(
+    sink = WriteToPubSub(
         'projects/fakeprj/topics/a_topic',
         id_label='id',
-        with_attributes=False,
         timestamp_attribute='time')
     dd = DisplayData.create_from(sink)
     expected_items = [
         DisplayDataItemMatcher('topic', 'projects/fakeprj/topics/a_topic'),
         DisplayDataItemMatcher('id_label', 'id'),
-        DisplayDataItemMatcher('with_attributes', False),
+        DisplayDataItemMatcher('with_attributes', True),
         DisplayDataItemMatcher('timestamp_attribute', 'time'),
     ]
 
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py 
b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index dd92286..204169c 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -1243,6 +1243,7 @@ class DataflowRunner(PipelineRunner):
         # Setting this property signals Dataflow runner to return full
         # PubsubMessages instead of just the data part of the payload.
         step.add_property(PropertyNames.PUBSUB_SERIALIZED_ATTRIBUTES_FN, '')
+
       if transform.source.timestamp_attribute is not None:
         step.add_property(
             PropertyNames.PUBSUB_TIMESTAMP_ATTRIBUTE,
@@ -1344,10 +1345,9 @@ class DataflowRunner(PipelineRunner):
       if transform.sink.id_label:
         step.add_property(
             PropertyNames.PUBSUB_ID_LABEL, transform.sink.id_label)
-      if transform.sink.with_attributes:
-        # Setting this property signals Dataflow runner that the PCollection
-        # contains PubsubMessage objects instead of just raw data.
-        step.add_property(PropertyNames.PUBSUB_SERIALIZED_ATTRIBUTES_FN, '')
+      # Setting this property signals Dataflow runner that the PCollection
+      # contains PubsubMessage objects instead of just raw data.
+      step.add_property(PropertyNames.PUBSUB_SERIALIZED_ATTRIBUTES_FN, '')
       if transform.sink.timestamp_attribute is not None:
         step.add_property(
             PropertyNames.PUBSUB_TIMESTAMP_ATTRIBUTE,
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py 
b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 8f221aa..e57b5f8 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -398,19 +398,19 @@ class _DirectWriteToPubSubFn(DoFn):
   BUFFER_SIZE_ELEMENTS = 100
   FLUSH_TIMEOUT_SECS = BUFFER_SIZE_ELEMENTS * 0.5
 
-  def __init__(self, sink):
-    self.project = sink.project
-    self.short_topic_name = sink.topic_name
-    self.id_label = sink.id_label
-    self.timestamp_attribute = sink.timestamp_attribute
-    self.with_attributes = sink.with_attributes
+  def __init__(self, transform):
+    self.project = transform.project
+    self.short_topic_name = transform.topic_name
+    self.id_label = transform.id_label
+    self.timestamp_attribute = transform.timestamp_attribute
+    self.with_attributes = transform.with_attributes
 
     # TODO(BEAM-4275): Add support for id_label and timestamp_attribute.
-    if sink.id_label:
+    if transform.id_label:
       raise NotImplementedError(
           'DirectRunner: id_label is not supported for '
           'PubSub writes')
-    if sink.timestamp_attribute:
+    if transform.timestamp_attribute:
       raise NotImplementedError(
           'DirectRunner: timestamp_attribute is not '
           'supported for PubSub writes')
@@ -475,8 +475,7 @@ def _get_pubsub_transform_overrides(pipeline_options):
         raise Exception(
             'PubSub I/O is only available in streaming mode '
             '(use the --streaming flag).')
-      return beam.ParDo(
-          _DirectWriteToPubSubFn(applied_ptransform.transform._sink))
+      return beam.ParDo(_DirectWriteToPubSubFn(applied_ptransform.transform))
 
   return [ReadFromPubSubOverride(), WriteToPubSubOverride()]
 

Reply via email to