This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 1a653c5acbe6ecd347ffbe3686eb324f88c0596a
Author: Dian Fu <dia...@apache.org>
AuthorDate: Fri Jul 2 15:12:09 2021 +0800

    [FLINK-23401][python] Refactor the construction of transformations into getTransforms
    
    This closes #16541.
---
 .../java/org/apache/flink/python/Constants.java    |  8 ++++
 .../beam/BeamDataStreamPythonFunctionRunner.java   | 35 ++++++++++++--
 .../python/beam/BeamPythonFunctionRunner.java      | 56 +++++++---------------
 .../python/beam/BeamTablePythonFunctionRunner.java | 31 ++++++++++--
 4 files changed, 85 insertions(+), 45 deletions(-)

diff --git a/flink-python/src/main/java/org/apache/flink/python/Constants.java 
b/flink-python/src/main/java/org/apache/flink/python/Constants.java
index 54d1d06..555ccc2 100644
--- a/flink-python/src/main/java/org/apache/flink/python/Constants.java
+++ b/flink-python/src/main/java/org/apache/flink/python/Constants.java
@@ -27,4 +27,12 @@ public class Constants {
 
     // coder urns
     public static final String FLINK_CODER_URN = "flink:coder:v1";
+
+    // execution graph
+    public static final String TRANSFORM_ID = "transform";
+    public static final String MAIN_INPUT_NAME = "input";
+    public static final String MAIN_OUTPUT_NAME = "output";
+
+    public static final String INPUT_COLLECTION_ID = "input";
+    public static final String OUTPUT_COLLECTION_ID = "output";
 }
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamDataStreamPythonFunctionRunner.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamDataStreamPythonFunctionRunner.java
index 5024ed0..a9f5b91 100644
--- 
a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamDataStreamPythonFunctionRunner.java
+++ 
b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamDataStreamPythonFunctionRunner.java
@@ -25,11 +25,21 @@ import org.apache.flink.python.env.PythonEnvironmentManager;
 import org.apache.flink.python.metric.FlinkMetricContainer;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.beam.model.pipeline.v1.RunnerApi;
 
 import javax.annotation.Nullable;
 
+import java.util.Collections;
 import java.util.Map;
 
+import static org.apache.flink.python.Constants.INPUT_COLLECTION_ID;
+import static org.apache.flink.python.Constants.MAIN_INPUT_NAME;
+import static org.apache.flink.python.Constants.MAIN_OUTPUT_NAME;
+import static org.apache.flink.python.Constants.OUTPUT_COLLECTION_ID;
+import static org.apache.flink.python.Constants.TRANSFORM_ID;
+
 /**
  * {@link BeamDataStreamPythonFunctionRunner} is responsible for starting a 
beam python harness to
  * execute user defined python function.
@@ -37,6 +47,7 @@ import java.util.Map;
 @Internal
 public class BeamDataStreamPythonFunctionRunner extends 
BeamPythonFunctionRunner {
 
+    private final String functionUrn;
     private final FlinkFnApi.UserDefinedDataStreamFunction 
userDefinedDataStreamFunction;
 
     public BeamDataStreamPythonFunctionRunner(
@@ -56,7 +67,6 @@ public class BeamDataStreamPythonFunctionRunner extends 
BeamPythonFunctionRunner
         super(
                 taskName,
                 environmentManager,
-                functionUrn,
                 jobOptions,
                 flinkMetricContainer,
                 stateBackend,
@@ -66,11 +76,28 @@ public class BeamDataStreamPythonFunctionRunner extends 
BeamPythonFunctionRunner
                 managedMemoryFraction,
                 inputCoderDescriptor,
                 outputCoderDescriptor);
-        this.userDefinedDataStreamFunction = userDefinedDataStreamFunction;
+        this.functionUrn = Preconditions.checkNotNull(functionUrn);
+        this.userDefinedDataStreamFunction =
+                Preconditions.checkNotNull(userDefinedDataStreamFunction);
     }
 
     @Override
-    protected byte[] getUserDefinedFunctionsProtoBytes() {
-        return this.userDefinedDataStreamFunction.toByteArray();
+    protected Map<String, RunnerApi.PTransform> getTransforms() {
+        return Collections.singletonMap(
+                TRANSFORM_ID,
+                RunnerApi.PTransform.newBuilder()
+                        .setUniqueName(TRANSFORM_ID)
+                        .setSpec(
+                                RunnerApi.FunctionSpec.newBuilder()
+                                        .setUrn(functionUrn)
+                                        .setPayload(
+                                                
org.apache.beam.vendor.grpc.v1p26p0.com.google
+                                                        
.protobuf.ByteString.copyFrom(
+                                                        
userDefinedDataStreamFunction
+                                                                
.toByteArray()))
+                                        .build())
+                        .putInputs(MAIN_INPUT_NAME, INPUT_COLLECTION_ID)
+                        .putOutputs(MAIN_OUTPUT_NAME, OUTPUT_COLLECTION_ID)
+                        .build());
     }
 }
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
index d485d6c..89cd6b7 100644
--- 
a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
+++ 
b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
@@ -81,6 +81,9 @@ import java.util.Map;
 import java.util.concurrent.LinkedBlockingQueue;
 
 import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;
+import static org.apache.flink.python.Constants.INPUT_COLLECTION_ID;
+import static org.apache.flink.python.Constants.OUTPUT_COLLECTION_ID;
+import static org.apache.flink.python.Constants.TRANSFORM_ID;
 import static org.apache.flink.streaming.api.utils.ProtoUtils.createCoderProto;
 
 /** A {@link BeamPythonFunctionRunner} used to execute Python functions. */
@@ -88,13 +91,6 @@ import static 
org.apache.flink.streaming.api.utils.ProtoUtils.createCoderProto;
 public abstract class BeamPythonFunctionRunner implements PythonFunctionRunner 
{
     protected static final Logger LOG = 
LoggerFactory.getLogger(BeamPythonFunctionRunner.class);
 
-    private static final String INPUT_ID = "input";
-    private static final String OUTPUT_ID = "output";
-    private static final String TRANSFORM_ID = "transform";
-
-    private static final String MAIN_INPUT_NAME = "input";
-    private static final String MAIN_OUTPUT_NAME = "output";
-
     private static final String INPUT_CODER_ID = "input_coder";
     private static final String OUTPUT_CODER_ID = "output_coder";
     private static final String WINDOW_CODER_ID = "window_coder";
@@ -109,9 +105,6 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
     /** The Python execution environment manager. */
     private final PythonEnvironmentManager environmentManager;
 
-    /** The urn which represents the function kind to be executed. */
-    private final String functionUrn;
-
     /** The options used to configure the Python worker process. */
     private final Map<String, String> jobOptions;
 
@@ -173,7 +166,6 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
     public BeamPythonFunctionRunner(
             String taskName,
             PythonEnvironmentManager environmentManager,
-            String functionUrn,
             Map<String, String> jobOptions,
             @Nullable FlinkMetricContainer flinkMetricContainer,
             @Nullable KeyedStateBackend keyedStateBackend,
@@ -185,7 +177,6 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
             FlinkFnApi.CoderInfoDescriptor outputCoderDescriptor) {
         this.taskName = Preconditions.checkNotNull(taskName);
         this.environmentManager = 
Preconditions.checkNotNull(environmentManager);
-        this.functionUrn = Preconditions.checkNotNull(functionUrn);
         this.jobOptions = Preconditions.checkNotNull(jobOptions);
         this.flinkMetricContainer = flinkMetricContainer;
         this.stateRequestHandler =
@@ -379,36 +370,20 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
     @SuppressWarnings("unchecked")
     private ExecutableStage createExecutableStage(RunnerApi.Environment 
environment)
             throws Exception {
-        RunnerApi.Components components =
+        RunnerApi.Components.Builder componentsBuilder =
                 RunnerApi.Components.newBuilder()
                         .putPcollections(
-                                INPUT_ID,
+                                INPUT_COLLECTION_ID,
                                 RunnerApi.PCollection.newBuilder()
                                         
.setWindowingStrategyId(WINDOW_STRATEGY)
                                         .setCoderId(INPUT_CODER_ID)
                                         .build())
                         .putPcollections(
-                                OUTPUT_ID,
+                                OUTPUT_COLLECTION_ID,
                                 RunnerApi.PCollection.newBuilder()
                                         
.setWindowingStrategyId(WINDOW_STRATEGY)
                                         .setCoderId(OUTPUT_CODER_ID)
                                         .build())
-                        .putTransforms(
-                                TRANSFORM_ID,
-                                RunnerApi.PTransform.newBuilder()
-                                        .setUniqueName(TRANSFORM_ID)
-                                        .setSpec(
-                                                
RunnerApi.FunctionSpec.newBuilder()
-                                                        .setUrn(functionUrn)
-                                                        .setPayload(
-                                                                
org.apache.beam.vendor.grpc.v1p26p0
-                                                                        
.com.google.protobuf
-                                                                        
.ByteString.copyFrom(
-                                                                        
getUserDefinedFunctionsProtoBytes()))
-                                                        .build())
-                                        .putInputs(MAIN_INPUT_NAME, INPUT_ID)
-                                        .putOutputs(MAIN_OUTPUT_NAME, 
OUTPUT_ID)
-                                        .build())
                         .putWindowingStrategies(
                                 WINDOW_STRATEGY,
                                 RunnerApi.WindowingStrategy.newBuilder()
@@ -416,11 +391,15 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
                                         .build())
                         .putCoders(INPUT_CODER_ID, 
createCoderProto(inputCoderDescriptor))
                         .putCoders(OUTPUT_CODER_ID, 
createCoderProto(outputCoderDescriptor))
-                        .putCoders(WINDOW_CODER_ID, getWindowCoderProto())
-                        .build();
+                        .putCoders(WINDOW_CODER_ID, getWindowCoderProto());
+
+        getTransforms().forEach(componentsBuilder::putTransforms);
+        RunnerApi.Components components = componentsBuilder.build();
 
         PipelineNode.PCollectionNode input =
-                PipelineNode.pCollection(INPUT_ID, 
components.getPcollectionsOrThrow(INPUT_ID));
+                PipelineNode.pCollection(
+                        INPUT_COLLECTION_ID,
+                        
components.getPcollectionsOrThrow(INPUT_COLLECTION_ID));
         List<SideInputReference> sideInputs = Collections.EMPTY_LIST;
         List<UserStateReference> userStates = Collections.EMPTY_LIST;
         List<TimerReference> timers = Collections.EMPTY_LIST;
@@ -431,7 +410,8 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
         List<PipelineNode.PCollectionNode> outputs =
                 Collections.singletonList(
                         PipelineNode.pCollection(
-                                OUTPUT_ID, 
components.getPcollectionsOrThrow(OUTPUT_ID)));
+                                OUTPUT_COLLECTION_ID,
+                                
components.getPcollectionsOrThrow(OUTPUT_COLLECTION_ID)));
         return ImmutableExecutableStage.of(
                 components,
                 environment,
@@ -459,14 +439,14 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
                         .setPayload(
                                 
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString
                                         .copyFrom(baos.toByteArray()))
-                        .setInputOrOutputId(INPUT_ID)
+                        .setInputOrOutputId(INPUT_COLLECTION_ID)
                         .build(),
                 RunnerApi.ExecutableStagePayload.WireCoderSetting.newBuilder()
                         
.setUrn(getUrn(RunnerApi.StandardCoders.Enum.PARAM_WINDOWED_VALUE))
                         .setPayload(
                                 
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString
                                         .copyFrom(baos.toByteArray()))
-                        .setInputOrOutputId(OUTPUT_ID)
+                        .setInputOrOutputId(OUTPUT_COLLECTION_ID)
                         .build());
     }
 
@@ -480,7 +460,7 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
                 .build();
     }
 
-    protected abstract byte[] getUserDefinedFunctionsProtoBytes();
+    protected abstract Map<String, RunnerApi.PTransform> getTransforms();
 
     // ------------------------------------------------------------------------
     // Construct RemoteBundler
diff --git 
a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/beam/BeamTablePythonFunctionRunner.java
 
b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/beam/BeamTablePythonFunctionRunner.java
index c4943c5..6e07a75 100644
--- 
a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/beam/BeamTablePythonFunctionRunner.java
+++ 
b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/beam/BeamTablePythonFunctionRunner.java
@@ -29,13 +29,24 @@ import 
org.apache.flink.streaming.api.runners.python.beam.BeamPythonFunctionRunn
 import org.apache.flink.util.Preconditions;
 
 import com.google.protobuf.GeneratedMessageV3;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
 
+import java.util.Collections;
 import java.util.Map;
 
+import static org.apache.flink.python.Constants.INPUT_COLLECTION_ID;
+import static org.apache.flink.python.Constants.MAIN_INPUT_NAME;
+import static org.apache.flink.python.Constants.MAIN_OUTPUT_NAME;
+import static org.apache.flink.python.Constants.OUTPUT_COLLECTION_ID;
+import static org.apache.flink.python.Constants.TRANSFORM_ID;
+
 /** A {@link BeamTablePythonFunctionRunner} used to execute Python functions 
in Table API. */
 @Internal
 public class BeamTablePythonFunctionRunner extends BeamPythonFunctionRunner {
 
+    /** The urn which represents the function kind to be executed. */
+    private final String functionUrn;
+
     private final GeneratedMessageV3 userDefinedFunctionProto;
 
     public BeamTablePythonFunctionRunner(
@@ -55,7 +66,6 @@ public class BeamTablePythonFunctionRunner extends 
BeamPythonFunctionRunner {
         super(
                 taskName,
                 environmentManager,
-                functionUrn,
                 jobOptions,
                 flinkMetricContainer,
                 keyedStateBackend,
@@ -65,11 +75,26 @@ public class BeamTablePythonFunctionRunner extends 
BeamPythonFunctionRunner {
                 managedMemoryFraction,
                 inputCoderDescriptor,
                 outputCoderDescriptor);
+        this.functionUrn = Preconditions.checkNotNull(functionUrn);
         this.userDefinedFunctionProto = 
Preconditions.checkNotNull(userDefinedFunctionProto);
     }
 
     @Override
-    protected byte[] getUserDefinedFunctionsProtoBytes() {
-        return userDefinedFunctionProto.toByteArray();
+    protected Map<String, RunnerApi.PTransform> getTransforms() {
+        return Collections.singletonMap(
+                TRANSFORM_ID,
+                RunnerApi.PTransform.newBuilder()
+                        .setUniqueName(TRANSFORM_ID)
+                        .setSpec(
+                                RunnerApi.FunctionSpec.newBuilder()
+                                        .setUrn(functionUrn)
+                                        .setPayload(
+                                                
org.apache.beam.vendor.grpc.v1p26p0.com.google
+                                                        
.protobuf.ByteString.copyFrom(
+                                                        
userDefinedFunctionProto.toByteArray()))
+                                        .build())
+                        .putInputs(MAIN_INPUT_NAME, INPUT_COLLECTION_ID)
+                        .putOutputs(MAIN_OUTPUT_NAME, OUTPUT_COLLECTION_ID)
+                        .build());
     }
 }

Reply via email to