This is an automated email from the ASF dual-hosted git repository.

libenchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 42b7e74ab20 [FLINK-33263][table-planner] Implement ParallelismProvider for sources in the table planner
42b7e74ab20 is described below

commit 42b7e74ab20785289b62f5dd68d566995ba9dcfc
Author: SuDewei <sudewei....@outlook.com>
AuthorDate: Thu Jan 18 16:05:40 2024 +0800

    [FLINK-33263][table-planner] Implement ParallelismProvider for sources in the table planner
    
    Close apache/flink#24128
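
    A usage sketch based on the tests added below ('values' is the planner's testing
    connector, 'scan.parallelism' is the option the tests set, and the table layout is
    illustrative): a scan source whose runtime provider implements ParallelismProvider
    runs with the configured parallelism, while downstream operators keep the default
    parallelism.

        // Java sketch; assumes flink-table-api-java on the classpath.
        // import org.apache.flink.table.api.EnvironmentSettings;
        // import org.apache.flink.table.api.TableEnvironment;
        TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inStreamingMode());
        tEnv.executeSql(
                "CREATE TABLE src (id INT, b STRING, c INT) WITH ("
                        + " 'connector' = 'values',"
                        + " 'bounded' = 'true',"
                        + " 'runtime-source' = 'DataStream',"
                        + " 'scan.parallelism' = '3')");
        tEnv.executeSql("SELECT * FROM src WHERE c > 1").print();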
---
 .../org/apache/flink/api/dag/Transformation.java   |  15 ++
 .../streaming/api/graph/StreamGraphGenerator.java  |   3 +
 .../SourceTransformationWrapper.java               |  72 ++++++++++
 .../exec/common/CommonExecTableSourceScan.java     | 154 ++++++++++++++++++---
 .../table/planner/delegation/BatchPlanner.scala    |   2 +-
 .../table/planner/delegation/PlannerBase.scala     |   3 +-
 .../table/planner/delegation/StreamPlanner.scala   |   2 +-
 .../planner/factories/TestValuesTableFactory.java  |  33 +++--
 .../planner/plan/stream/sql/TableScanTest.xml      |  42 ++++++
 .../planner/plan/stream/sql/TableScanTest.scala    |  38 +++++
 .../runtime/stream/sql/TableSourceITCase.scala     |  80 +++++++++++
 .../flink/table/planner/utils/TableTestBase.scala  |  51 ++++++-
 12 files changed, 463 insertions(+), 32 deletions(-)
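
For connector developers, the planner now consults ParallelismProvider on the scan runtime
provider. A hedged connector-side sketch follows; only SourceProvider.of(source, parallelism)
is taken from the diff below, while the class and parameter names (ScanParallelismSketch,
buildProvider, configuredParallelism) are illustrative placeholders, not part of this commit.

    import javax.annotation.Nullable;

    import org.apache.flink.api.connector.source.Source;
    import org.apache.flink.table.connector.source.ScanTableSource;
    import org.apache.flink.table.connector.source.SourceProvider;
    import org.apache.flink.table.data.RowData;

    final class ScanParallelismSketch {
        static ScanTableSource.ScanRuntimeProvider buildProvider(
                Source<RowData, ?, ?> source, @Nullable Integer configuredParallelism) {
            // SourceProvider.of(source, parallelism) also implements ParallelismProvider;
            // a null parallelism means "not configured" and keeps the previous behaviour.
            return SourceProvider.of(source, configuredParallelism);
        }
    }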

diff --git a/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java b/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java
index a0448697dd1..6256f9624f6 100644
--- a/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java
+++ b/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java
@@ -19,6 +19,7 @@
 package org.apache.flink.api.dag;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.InvalidTypesException;
 import org.apache.flink.api.common.operators.ResourceSpec;
@@ -602,6 +603,20 @@ public abstract class Transformation<T> {
                 + '}';
     }
 
+    @VisibleForTesting
+    public String toStringWithoutId() {
+        return getClass().getSimpleName()
+                + "{"
+                + "name='"
+                + name
+                + '\''
+                + ", outputType="
+                + outputType
+                + ", parallelism="
+                + parallelism
+                + '}';
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
index 5929a2a5e8e..8e267ff84d6 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
@@ -65,6 +65,7 @@ import org.apache.flink.streaming.api.transformations.ReduceTransformation;
 import org.apache.flink.streaming.api.transformations.SideOutputTransformation;
 import org.apache.flink.streaming.api.transformations.SinkTransformation;
 import org.apache.flink.streaming.api.transformations.SourceTransformation;
+import org.apache.flink.streaming.api.transformations.SourceTransformationWrapper;
 import org.apache.flink.streaming.api.transformations.TimestampsAndWatermarksTransformation;
 import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
 import org.apache.flink.streaming.api.transformations.UnionTransformation;
@@ -553,6 +554,8 @@ public class StreamGraphGenerator {
             transformedIds = transformFeedback((FeedbackTransformation<?>) transform);
         } else if (transform instanceof CoFeedbackTransformation<?>) {
             transformedIds = transformCoFeedback((CoFeedbackTransformation<?>) transform);
+        } else if (transform instanceof SourceTransformationWrapper<?>) {
+            transformedIds = transform(((SourceTransformationWrapper<?>) transform).getInput());
         } else {
             throw new IllegalStateException("Unknown transformation: " + transform);
         }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformationWrapper.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformationWrapper.java
new file mode 100644
index 00000000000..d536000fde2
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformationWrapper.java
@@ -0,0 +1,72 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package org.apache.flink.streaming.api.transformations;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.streaming.api.graph.TransformationTranslator;
+
+import org.apache.flink.shaded.guava31.com.google.common.collect.Lists;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * This Transformation is a phantom transformation which is used to expose a default parallelism to
+ * downstream.
+ *
+ * <p>It is used only when the parallelism of the source transformation differs from the default
+ * parallelism, ensuring that the parallelism of downstream operations is not affected.
+ *
+ * <p>Moreover, this transformation does not have a corresponding {@link TransformationTranslator},
+ * meaning it will not become a node in the StreamGraph.
+ *
+ * @param <T> The type of the elements in the input {@code Transformation}
+ */
+@Internal
+public class SourceTransformationWrapper<T> extends Transformation<T> {
+
+    private final Transformation<T> input;
+
+    public SourceTransformationWrapper(Transformation<T> input) {
+        super(
+                "ChangeToDefaultParallel",
+                input.getOutputType(),
+                ExecutionConfig.PARALLELISM_DEFAULT);
+        this.input = input;
+    }
+
+    public Transformation<T> getInput() {
+        return input;
+    }
+
+    @Override
+    public List<Transformation<?>> getTransitivePredecessors() {
+        List<Transformation<?>> result = Lists.newArrayList();
+        result.add(this);
+        result.addAll(input.getTransitivePredecessors());
+        return result;
+    }
+
+    @Override
+    public List<Transformation<?>> getInputs() {
+        return Collections.singletonList(input);
+    }
+}
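
A small sketch of the wrapper's effect (the types are from this patch; the wrapping helper and
the chosen parallelism of 3 are assumptions for illustration): the wrapper reports the default
parallelism downstream while the wrapped source keeps the parallelism configured on it, and the
StreamGraphGenerator change above simply translates the wrapped input.

    import org.apache.flink.api.common.ExecutionConfig;
    import org.apache.flink.api.dag.Transformation;
    import org.apache.flink.streaming.api.transformations.SourceTransformationWrapper;

    final class WrapperSketch {
        static <T> Transformation<T> wrap(Transformation<T> source) {
            source.setParallelism(3); // parallelism chosen by the connector
            Transformation<T> wrapped = new SourceTransformationWrapper<>(source);
            // The wrapper exposes PARALLELISM_DEFAULT (-1) to downstream operators ...
            assert wrapped.getParallelism() == ExecutionConfig.PARALLELISM_DEFAULT;
            // ... and hands its input to the StreamGraphGenerator via getInputs().
            assert wrapped.getInputs().get(0) == source;
            return wrapped;
        }
    }
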
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java
index dc69543cd28..be5b46ba973 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.table.planner.plan.nodes.exec.common;
 
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.eventtime.WatermarkStrategy;
 import org.apache.flink.api.common.io.InputFormat;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -30,6 +31,13 @@ import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.streaming.api.transformations.LegacySourceTransformation;
+import org.apache.flink.streaming.api.transformations.PartitionTransformation;
+import org.apache.flink.streaming.api.transformations.SourceTransformationWrapper;
+import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.connector.ChangelogMode;
+import org.apache.flink.table.connector.ParallelismProvider;
 import org.apache.flink.table.connector.ProviderContext;
 import org.apache.flink.table.connector.source.DataStreamScanProvider;
 import org.apache.flink.table.connector.source.InputFormatProvider;
@@ -48,17 +56,22 @@ import org.apache.flink.table.planner.plan.nodes.exec.MultipleTransformationTran
 import org.apache.flink.table.planner.plan.nodes.exec.spec.DynamicTableSourceSpec;
 import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecNode;
 import org.apache.flink.table.planner.plan.nodes.exec.utils.TransformationMetadata;
+import org.apache.flink.table.planner.plan.utils.KeySelectorUtil;
 import org.apache.flink.table.planner.utils.ShortcutUtils;
 import org.apache.flink.table.runtime.connector.source.ScanRuntimeProviderContext;
+import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.RowKind;
 
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
 
 import java.util.List;
 import java.util.Optional;
 
+import static org.apache.flink.runtime.state.KeyGroupRangeAssignment.DEFAULT_LOWER_BOUND_MAX_PARALLELISM;
+
 /**
  * Base {@link ExecNode} to read data from an external source defined by a {@link ScanTableSource}.
  */
@@ -96,6 +109,7 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData>
     @Override
     protected Transformation<RowData> translateToPlanInternal(
             PlannerBase planner, ExecNodeConfig config) {
+        final Transformation<RowData> sourceTransform;
         final StreamExecutionEnvironment env = planner.getExecEnv();
         final TransformationMetadata meta = createTransformationMeta(SOURCE_TRANSFORMATION, config);
         final InternalTypeInfo<RowData> outputTypeInfo =
@@ -105,54 +119,149 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData>
                         planner.getFlinkContext(), ShortcutUtils.unwrapTypeFactory(planner));
         ScanTableSource.ScanRuntimeProvider provider =
                 tableSource.getScanRuntimeProvider(ScanRuntimeProviderContext.INSTANCE);
+        final int sourceParallelism = deriveSourceParallelism(provider);
+        final boolean sourceParallelismConfigured = isParallelismConfigured(provider);
         if (provider instanceof SourceFunctionProvider) {
             final SourceFunctionProvider sourceFunctionProvider = (SourceFunctionProvider) provider;
             final SourceFunction<RowData> function = sourceFunctionProvider.createSourceFunction();
-            final Transformation<RowData> transformation =
+            sourceTransform =
                     createSourceFunctionTransformation(
                             env,
                             function,
                             sourceFunctionProvider.isBounded(),
                             meta.getName(),
-                            outputTypeInfo);
-            return meta.fill(transformation);
+                            outputTypeInfo,
+                            sourceParallelism,
+                            sourceParallelismConfigured);
+            if (function instanceof ParallelSourceFunction && sourceParallelismConfigured) {
+                meta.fill(sourceTransform);
+                return new SourceTransformationWrapper<>(sourceTransform);
+            } else {
+                return meta.fill(sourceTransform);
+            }
         } else if (provider instanceof InputFormatProvider) {
             final InputFormat<RowData, ?> inputFormat =
                     ((InputFormatProvider) provider).createInputFormat();
-            final Transformation<RowData> transformation =
+            sourceTransform =
                     createInputFormatTransformation(
                             env, inputFormat, outputTypeInfo, meta.getName());
-            return meta.fill(transformation);
+            meta.fill(sourceTransform);
         } else if (provider instanceof SourceProvider) {
             final Source<RowData, ?, ?> source = ((SourceProvider) provider).createSource();
             // TODO: Push down watermark strategy to source scan
-            final Transformation<RowData> transformation =
+            sourceTransform =
                     env.fromSource(
                                     source,
                                     WatermarkStrategy.noWatermarks(),
                                     meta.getName(),
                                     outputTypeInfo)
                             .getTransformation();
-            return meta.fill(transformation);
+            meta.fill(sourceTransform);
         } else if (provider instanceof DataStreamScanProvider) {
-            Transformation<RowData> transformation =
+            sourceTransform =
                     ((DataStreamScanProvider) provider)
                             .produceDataStream(createProviderContext(config), env)
                             .getTransformation();
-            meta.fill(transformation);
-            transformation.setOutputType(outputTypeInfo);
-            return transformation;
+            meta.fill(sourceTransform);
+            sourceTransform.setOutputType(outputTypeInfo);
         } else if (provider instanceof TransformationScanProvider) {
-            final Transformation<RowData> transformation =
+            sourceTransform =
                     ((TransformationScanProvider) provider)
                             .createTransformation(createProviderContext(config));
-            meta.fill(transformation);
-            transformation.setOutputType(outputTypeInfo);
-            return transformation;
+            meta.fill(sourceTransform);
+            sourceTransform.setOutputType(outputTypeInfo);
         } else {
             throw new UnsupportedOperationException(
                     provider.getClass().getSimpleName() + " is unsupported now.");
         }
+
+        if (sourceParallelismConfigured) {
+            return applySourceTransformationWrapper(
+                    sourceTransform,
+                    planner.getFlinkContext().getClassLoader(),
+                    outputTypeInfo,
+                    config,
+                    tableSource.getChangelogMode(),
+                    sourceParallelism);
+        } else {
+            return sourceTransform;
+        }
+    }
+
+    private boolean isParallelismConfigured(ScanTableSource.ScanRuntimeProvider runtimeProvider) {
+        return runtimeProvider instanceof ParallelismProvider
+                && ((ParallelismProvider) runtimeProvider).getParallelism().isPresent();
+    }
+
+    private int deriveSourceParallelism(ScanTableSource.ScanRuntimeProvider runtimeProvider) {
+        if (isParallelismConfigured(runtimeProvider)) {
+            int sourceParallelism = ((ParallelismProvider) runtimeProvider).getParallelism().get();
+            if (sourceParallelism <= 0) {
+                throw new TableException(
+                        String.format(
+                                "Invalid configured parallelism %s for table '%s'.",
+                                sourceParallelism,
+                                tableSourceSpec
+                                        .getContextResolvedTable()
+                                        .getIdentifier()
+                                        .asSummaryString()));
+            }
+            return sourceParallelism;
+        } else {
+            return ExecutionConfig.PARALLELISM_DEFAULT;
+        }
+    }
+
+    protected RowType getPhysicalRowType(ResolvedSchema schema) {
+        return (RowType) schema.toPhysicalRowDataType().getLogicalType();
+    }
+
+    protected int[] getPrimaryKeyIndices(RowType sourceRowType, ResolvedSchema schema) {
+        return schema.getPrimaryKey()
+                .map(k -> k.getColumns().stream().mapToInt(sourceRowType::getFieldIndex).toArray())
+                .orElse(new int[0]);
+    }
+
+    private Transformation<RowData> applySourceTransformationWrapper(
+            Transformation<RowData> sourceTransform,
+            ClassLoader classLoader,
+            InternalTypeInfo<RowData> outputTypeInfo,
+            ExecNodeConfig config,
+            ChangelogMode changelogMode,
+            int sourceParallelism) {
+        sourceTransform.setParallelism(sourceParallelism, true);
+        Transformation<RowData> sourceTransformationWrapper =
+                new SourceTransformationWrapper<>(sourceTransform);
+
+        if (!changelogMode.containsOnly(RowKind.INSERT)) {
+            final ResolvedSchema schema =
+                    tableSourceSpec.getContextResolvedTable().getResolvedSchema();
+            final RowType physicalRowType = getPhysicalRowType(schema);
+            final int[] primaryKeys = getPrimaryKeyIndices(physicalRowType, schema);
+            final boolean hasPk = primaryKeys.length > 0;
+            if (!hasPk) {
+                throw new TableException(
+                        String.format(
+                                "Configured parallelism %s for upsert table '%s' while can not find primary key field. "
+                                        + "This is a bug, please file an issue.",
+                                sourceParallelism,
+                                tableSourceSpec
+                                        .getContextResolvedTable()
+                                        .getIdentifier()
+                                        .asSummaryString()));
+            }
+            final RowDataKeySelector selector =
+                    KeySelectorUtil.getRowDataSelector(classLoader, primaryKeys, outputTypeInfo);
+            final KeyGroupStreamPartitioner<RowData, RowData> partitioner =
+                    new KeyGroupStreamPartitioner<>(selector, DEFAULT_LOWER_BOUND_MAX_PARALLELISM);
+            Transformation<RowData> partitionedTransform =
+                    new PartitionTransformation<>(sourceTransformationWrapper, partitioner);
+            createTransformationMeta("partitioner", "Partitioner", "Partitioner", config)
+                    .fill(partitionedTransform);
+            return partitionedTransform;
+        } else {
+            return sourceTransformationWrapper;
+        }
     }
 
     private ProviderContext createProviderContext(ExecNodeConfig config) {
@@ -178,17 +287,22 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData>
             SourceFunction<RowData> function,
             boolean isBounded,
             String operatorName,
-            TypeInformation<RowData> outputTypeInfo) {
+            TypeInformation<RowData> outputTypeInfo,
+            int sourceParallelism,
+            boolean sourceParallelismConfigured) {
 
         env.clean(function);
 
         final int parallelism;
-        boolean parallelismConfigured = false;
         if (function instanceof ParallelSourceFunction) {
-            parallelism = env.getParallelism();
+            if (sourceParallelismConfigured) {
+                parallelism = sourceParallelism;
+            } else {
+                parallelism = env.getParallelism();
+            }
         } else {
             parallelism = 1;
-            parallelismConfigured = true;
+            sourceParallelismConfigured = true;
         }
 
         final Boundedness boundedness;
@@ -205,7 +319,7 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData>
                 outputTypeInfo,
                 parallelism,
                 boundedness,
-                parallelismConfigured);
+                sourceParallelismConfigured);
     }
 
     /**
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala
index bb4c1b75a28..cea10f7bb81 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala
@@ -84,7 +84,7 @@ class BatchPlanner(
     processors
   }
 
-  override protected def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = {
+  override def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = {
     beforeTranslation()
     val planner = createDummyPlanner()
 
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
index 45788e6278e..b36edaa21d7 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
@@ -367,7 +367,8 @@ abstract class PlannerBase(
    * @return
    *   The [[Transformation]] DAG that corresponds to the node DAG.
    */
-  protected def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]]
+  @VisibleForTesting
+  def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]]
 
   def addExtraTransformation(transformation: Transformation[_]): Unit = {
     if (!extraTransformations.contains(transformation)) {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala
index fb32326f117..894a37c8cf9 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala
@@ -78,7 +78,7 @@ class StreamPlanner(
 
   override protected def getExecNodeGraphProcessors: Seq[ExecNodeGraphProcessor] = Seq()
 
-  override protected def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = {
+  override def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = {
     beforeTranslation()
     val planner = createDummyPlanner()
     val transformations = execGraph.getRootNodes.map {
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
index 3dbf4d5b9c0..db64847b75b 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
@@ -466,6 +466,8 @@ public final class TestValuesTableFactory
     private static final ConfigOption<String> SINK_CHANGELOG_MODE_ENFORCED =
             ConfigOptions.key("sink-changelog-mode-enforced").stringType().noDefaultValue();
 
+    private static final ConfigOption<Integer> SOURCE_PARALLELISM = FactoryUtil.SOURCE_PARALLELISM;
+
     private static final ConfigOption<Integer> SINK_PARALLELISM = FactoryUtil.SINK_PARALLELISM;
 
     @Override
@@ -497,6 +499,7 @@ public final class TestValuesTableFactory
         int lookupThreshold = helper.getOptions().get(LOOKUP_THRESHOLD);
         int sleepAfterElements = helper.getOptions().get(SOURCE_SLEEP_AFTER_ELEMENTS);
         long sleepTimeMillis = helper.getOptions().get(SOURCE_SLEEP_TIME).toMillis();
+        Integer parallelism = helper.getOptions().get(SOURCE_PARALLELISM);
         DefaultLookupCache cache = null;
         if (helper.getOptions().get(CACHE_TYPE).equals(LookupOptions.LookupCacheType.PARTIAL)) {
             cache = DefaultLookupCache.fromConfig(helper.getOptions());
@@ -571,7 +574,8 @@ public final class TestValuesTableFactory
                         Long.MAX_VALUE,
                         partitions,
                         readableMetadata,
-                        null);
+                        null,
+                        parallelism);
             }
 
             if (disableLookup) {
@@ -746,6 +750,7 @@ public final class TestValuesTableFactory
                         SOURCE_NUM_ELEMENT_TO_SKIP,
                         SOURCE_SLEEP_AFTER_ELEMENTS,
                         SOURCE_SLEEP_TIME,
+                        SOURCE_PARALLELISM,
                         INTERNAL_DATA,
                         CACHE_TYPE,
                         PARTIAL_CACHE_EXPIRE_AFTER_ACCESS,
@@ -916,6 +921,7 @@ public final class TestValuesTableFactory
         private @Nullable int[] groupingSet;
         private List<AggregateExpression> aggregateExpressions;
         private List<String> acceptedPartitionFilterFields;
+        private final Integer parallelism;
 
         private TestValuesScanTableSourceWithoutProjectionPushDown(
                 DataType producedDataType,
@@ -934,7 +940,8 @@ public final class TestValuesTableFactory
                 long limit,
                 List<Map<String, String>> allPartitions,
                 Map<String, DataType> readableMetadata,
-                @Nullable int[] projectedMetadataFields) {
+                @Nullable int[] projectedMetadataFields,
+                @Nullable Integer parallelism) {
             this.producedDataType = producedDataType;
             this.changelogMode = changelogMode;
             this.boundedness = boundedness;
@@ -954,6 +961,7 @@ public final class TestValuesTableFactory
             this.projectedMetadataFields = projectedMetadataFields;
             this.groupingSet = null;
             this.aggregateExpressions = Collections.emptyList();
+            this.parallelism = parallelism;
         }
 
         @Override
@@ -987,7 +995,7 @@ public final class TestValuesTableFactory
                             sourceFunction = new FromElementsFunction<>(serializer, values);
                         }
                         return SourceFunctionProvider.of(
-                                sourceFunction, boundedness == Boundedness.BOUNDED);
+                                sourceFunction, boundedness == Boundedness.BOUNDED, parallelism);
                     } catch (IOException e) {
                         throw new TableException("Fail to init source function", e);
                     }
@@ -999,7 +1007,8 @@ public final class TestValuesTableFactory
                             terminating == TerminatingLogic.FINITE,
                             "Values Source doesn't support infinite 
InputFormat.");
                     Collection<RowData> values = convertToRowData(converter);
-                    return InputFormatProvider.of(new 
CollectionInputFormat<>(values, serializer));
+                    return InputFormatProvider.of(
+                            new CollectionInputFormat<>(values, serializer), 
parallelism);
                 case "DataStream":
                     checkArgument(
                             !failingSource,
@@ -1024,6 +1033,11 @@ public final class TestValuesTableFactory
                                 return sourceStream;
                             }
 
+                            @Override
+                            public Optional<Integer> getParallelism() {
+                                return Optional.ofNullable(parallelism);
+                            }
+
                             @Override
                             public boolean isBounded() {
                                 return boundedness == Boundedness.BOUNDED;
@@ -1039,7 +1053,8 @@ public final class TestValuesTableFactory
                             || acceptedPartitionFilterFields.isEmpty()) {
                         Collection<RowData> values2 = convertToRowData(converter);
                         return SourceProvider.of(
-                                new ValuesSource(terminating, boundedness, values2, serializer));
+                                new ValuesSource(terminating, boundedness, values2, serializer),
+                                parallelism);
                     } else {
                         Map<Map<String, String>, Collection<RowData>> partitionValues =
                                 convertToPartitionedRowData(converter);
@@ -1050,7 +1065,7 @@ public final class TestValuesTableFactory
                                         partitionValues,
                                         serializer,
                                         acceptedPartitionFilterFields);
-                        return SourceProvider.of(source);
+                        return SourceProvider.of(source, parallelism);
                     }
                 default:
                     throw new IllegalArgumentException(
@@ -1114,7 +1129,8 @@ public final class TestValuesTableFactory
                     limit,
                     allPartitions,
                     readableMetadata,
-                    projectedMetadataFields);
+                    projectedMetadataFields,
+                    parallelism);
         }
 
         @Override
@@ -1477,7 +1493,8 @@ public final class TestValuesTableFactory
                     limit,
                     allPartitions,
                     readableMetadata,
-                    projectedMetadataFields);
+                    projectedMetadataFields,
+                    null);
         }
 
         @Override
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml
index 52d21087262..8fe6835213c 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml
@@ -727,4 +727,46 @@ Calc(select=[ts, a, b], where=[>(a, 1)], changelogMode=[I,UB,UA,D])
 ]]>
     </Resource>
   </TestCase>
+
+  <TestCase name="testSetParallelismForSource">
+    <Resource name="sql">
+      <![CDATA[SELECT * FROM src LEFT JOIN changelog_src on src.id = changelog_src.id WHERE src.c > 1]]>
+    </Resource>
+       <Resource name="ast">
+      <![CDATA[
+LogicalProject(id=[$0], b=[$1], c=[$2], id0=[$3], a=[$4])
++- LogicalFilter(condition=[>($2, 1)])
+   +- LogicalJoin(condition=[=($0, $3)], joinType=[left])
+      :- LogicalTableScan(table=[[default_catalog, default_database, src]])
+      +- LogicalTableScan(table=[[default_catalog, default_database, changelog_src]])
+]]>
+    </Resource>
+    <Resource name="optimized exec plan">
+      <![CDATA[
+Join(joinType=[LeftOuterJoin], where=[(id = id0)], select=[id, b, c, id0, a], leftInputSpec=[NoUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey])
+:- Exchange(distribution=[hash[id]])
+:  +- Calc(select=[id, b, c], where=[(c > 1)])
+:     +- TableSourceScan(table=[[default_catalog, default_database, src, filter=[]]], fields=[id, b, c])
++- Exchange(distribution=[hash[id]])
+   +- ChangelogNormalize(key=[id])
+      +- Exchange(distribution=[hash[id]])
+         +- TableSourceScan(table=[[default_catalog, default_database, changelog_src]], fields=[id, a])
+]]>
+       </Resource>
+       <Resource name="transformation">
+      <![CDATA[
+TwoInputTransformation{name='Join(joinType=[LeftOuterJoin], where=[(id = id0)], select=[id, b, c, id0, a], leftInputSpec=[NoUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey])', outputType=ROW<`id` INT, `b` STRING, `c` INT, `id0` INT, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+       +- PartitionTransformation{name='Exchange(distribution=[hash[id]])', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+               +- OneInputTransformation{name='Calc(select=[id, b, c], where=[(c > 1)])', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+                       +- SourceTransformationWrapper{name='ChangeToDefaultParallel', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+                               +- LegacySourceTransformation{name='TableSourceScan(table=[[default_catalog, default_database, src, filter=[]]], fields=[id, b, c])', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=3}
+       +- PartitionTransformation{name='Exchange(distribution=[hash[id]])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+               +- OneInputTransformation{name='ChangelogNormalize(key=[id])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+                       +- PartitionTransformation{name='Exchange(distribution=[hash[id]])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+                               +- PartitionTransformation{name='Partitioner', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+                                       +- SourceTransformationWrapper{name='ChangeToDefaultParallel', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+                                               +- LegacySourceTransformation{name='TableSourceScan(table=[[default_catalog, default_database, changelog_src]], fields=[id, a])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=5}
+]]>
+       </Resource>
+  </TestCase>
 </Root>
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala
index 0a31589b61c..be1ae70d3fa 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala
@@ -775,4 +775,42 @@ class TableScanTest extends TableTestBase {
           "expression type is CHAR(0) NOT NULL")
       .isInstanceOf[ValidationException]
   }
+
+  @Test
+  def testSetParallelismForSource(): Unit = {
+    val config = TableConfig.getDefault
+    config.set(ExecutionConfigOptions.TABLE_EXEC_SIMPLIFY_OPERATOR_NAME_ENABLED, Boolean.box(false))
+    val util = streamTestUtil(config)
+
+    util.addTable("""
+                    |CREATE TABLE changelog_src (
+                    |  id INT,
+                    |  a STRING,
+                    |  PRIMARY KEY (id) NOT ENFORCED
+                    |) WITH (
+                    |  'connector' = 'values',
+                    |  'bounded' = 'true',
+                    |  'runtime-source' = 'DataStream',
+                    |  'scan.parallelism' = '5',
+                    |  'enable-projection-push-down' = 'false',
+                    |  'changelog-mode' = 'I,UA,D'
+                    |)
+      """.stripMargin)
+    util.addTable("""
+                    |CREATE TABLE src (
+                    |  id INT,
+                    |  b STRING,
+                    |  c INT
+                    |) WITH (
+                    |  'connector' = 'values',
+                    |  'bounded' = 'true',
+                    |  'runtime-source' = 'DataStream',
+                    |  'scan.parallelism' = '3',
+                    |  'enable-projection-push-down' = 'false'
+                    |)
+      """.stripMargin)
+    util.verifyTransformation(
+      "SELECT * FROM src LEFT JOIN changelog_src " +
+        "on src.id = changelog_src.id WHERE src.c > 1")
+  }
 }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala
index 26b5d3a1709..a2089ee404e 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala
@@ -24,6 +24,7 @@ import org.apache.flink.table.api.{DataTypes, TableException}
 import org.apache.flink.table.api.bridge.scala._
 import org.apache.flink.table.planner.factories.TestValuesTableFactory
 import org.apache.flink.table.planner.runtime.utils.{StreamingTestBase, TestData, TestingAppendSink, TestingRetractSink}
+import org.apache.flink.table.planner.runtime.utils.TestData.data1
 import org.apache.flink.table.planner.utils._
 import org.apache.flink.table.runtime.functions.scalar.SourceWatermarkFunction
 import org.apache.flink.table.utils.LegacyRowExtension
@@ -33,6 +34,8 @@ import org.assertj.core.api.Assertions.{assertThat, assertThatThrownBy}
 import org.junit.jupiter.api.{BeforeEach, Test}
 import org.junit.jupiter.api.extension.RegisterExtension
 
+import java.util.concurrent.atomic.AtomicInteger
+
 class TableSourceITCase extends StreamingTestBase {
 
   @RegisterExtension private val _: EachCallbackWrapper[LegacyRowExtension] =
@@ -421,4 +424,81 @@ class TableSourceITCase extends StreamingTestBase {
     val expected = Seq("1,Sarah,1", "2,Rob,1", "3,Mike,1")
     assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted)
   }
+
+  private def innerTestSetParallelism(provider: String, parallelism: Int, index: Int): Unit = {
+    val dataId = TestValuesTableFactory.registerData(data1)
+    val sourceTableName = s"test_para_source_${provider.toLowerCase.trim}_$index"
+    val sinkTableName = s"test_para_sink_${provider.toLowerCase.trim}_$index"
+    tEnv.executeSql(s"""
+                       |CREATE TABLE $sourceTableName (
+                       |  the_month INT,
+                       |  area STRING,
+                       |  product INT
+                       |) WITH (
+                       |  'connector' = 'values',
+                       |  'data-id' = '$dataId',
+                       |  'bounded' = 'true',
+                       |  'runtime-source' = '$provider',
+                       |  'scan.parallelism' = '$parallelism',
+                       |  'enable-projection-push-down' = 'false'
+                       |)
+                       |""".stripMargin)
+    tEnv.executeSql(s"""
+                       |CREATE TABLE $sinkTableName (
+                       |  the_month INT,
+                       |  area STRING,
+                       |  product INT
+                       |) WITH (
+                       |  'connector' = 'values',
+                       |  'sink-insert-only' = 'true'
+                       |)
+                       |""".stripMargin)
+    tEnv.executeSql(s"INSERT INTO $sinkTableName SELECT * FROM 
$sourceTableName").await()
+  }
+
+  @Test
+  def testParallelismWithSourceFunction(): Unit = {
+    val negativeParallelism = -1
+    val validParallelism = 3
+    val index = new AtomicInteger(1)
+
+    assertThatThrownBy(
+      () =>
+        innerTestSetParallelism(
+          "SourceFunction",
+          negativeParallelism,
+          index = index.getAndIncrement))
+      .hasMessageContaining(s"Invalid configured parallelism")
+
+    innerTestSetParallelism("SourceFunction", validParallelism, index = 
index.getAndIncrement)
+  }
+
+  @Test
+  def testParallelismWithInputFormat(): Unit = {
+    val negativeParallelism = -1
+    val validParallelism = 3
+    val index = new AtomicInteger(2)
+
+    assertThatThrownBy(
+      () =>
+        innerTestSetParallelism("InputFormat", negativeParallelism, index = index.getAndIncrement))
+      .hasMessageContaining(s"Invalid configured parallelism")
+
+    innerTestSetParallelism("InputFormat", validParallelism, index = 
index.getAndIncrement)
+  }
+
+  @Test
+  def testParallelismWithDataStream(): Unit = {
+    val negativeParallelism = -1
+    val validParallelism = 3
+    val index = new AtomicInteger(3)
+
+    assertThatThrownBy(
+      () =>
+        innerTestSetParallelism("DataStream", negativeParallelism, index = index.getAndIncrement))
+      .hasMessageContaining(s"Invalid configured parallelism")
+
+    innerTestSetParallelism("DataStream", validParallelism, index = 
index.getAndIncrement)
+  }
+
 }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
index 1e006f3d94b..e5b418365be 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
@@ -19,6 +19,7 @@ package org.apache.flink.table.planner.utils
 
 import org.apache.flink.FlinkVersion
 import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
+import org.apache.flink.api.dag.Transformation
 import org.apache.flink.api.java.typeutils.{PojoTypeInfo, RowTypeInfo, TupleTypeInfo}
 import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
 import org.apache.flink.configuration.BatchExecutionOptions
@@ -87,7 +88,7 @@ import org.junit.jupiter.api.extension.{BeforeEachCallback, ExtendWith, Extensio
 import org.junit.jupiter.api.io.TempDir
 import org.junit.platform.commons.support.AnnotationSupport
 
-import java.io.{File, IOException}
+import java.io.{File, IOException, PrintWriter, StringWriter}
 import java.net.URL
 import java.nio.file.{Files, Path, Paths}
 import java.time.Duration
@@ -702,6 +703,20 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean)
       withQueryBlockAlias = false)
   }
 
+  /**
+   * Verify the AST (abstract syntax tree), the optimized exec plan and the transformation for the
+   * given SELECT query. Note: An exception will be thrown if the given sql can't be translated to
+   * an exec plan or the transformation result is wrong.
+   */
+  def verifyTransformation(query: String): Unit = {
+    doVerifyPlan(
+      query,
+      Array.empty[ExplainDetail],
+      withRowType = false,
+      Array(PlanKind.AST, PlanKind.OPT_EXEC, PlanKind.TRANSFORM),
+      withQueryBlockAlias = false)
+  }
+
   /** Verify the explain result for the given SELECT query. See more about [[Table#explain()]]. */
   def verifyExplain(query: String): Unit = verifyExplain(getTableEnv.sqlQuery(query))
 
@@ -1040,6 +1055,14 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean)
       ""
     }
 
+    // build transformation graph if `expectedPlans` contains TRANSFORM
+    val transformation = if (expectedPlans.contains(PlanKind.TRANSFORM)) {
+      val optimizedNodes = getPlanner.translateToExecNodeGraph(optimizedRels, true)
+      System.lineSeparator + getTransformations(getPlanner.translateToPlan(optimizedNodes))
+    } else {
+      ""
+    }
+
     // check whether the sql equals to the expected if the `relNodes` are translated from sql
     assertSqlEqualsOrExpandFunc()
     // check ast plan
@@ -1058,6 +1081,10 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean)
     if (expectedPlans.contains(PlanKind.OPT_EXEC)) {
       assertEqualsOrExpand("optimized exec plan", optimizedExecPlan, expand = false)
     }
+    // check transformation graph
+    if (expectedPlans.contains(PlanKind.TRANSFORM)) {
+      assertEqualsOrExpand("transformation", transformation, expand = false)
+    }
   }
 
   private def doVerifyExplain(explainResult: String, extraDetails: ExplainDetail*): Unit = {
@@ -1117,6 +1144,25 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean)
     replaceEstimatedCost(optimizedPlan)
   }
 
+  private def getTransformations(transformations: java.util.List[Transformation[_]]): String = {
+    val stringWriter = new StringWriter()
+    val printWriter = new PrintWriter(stringWriter)
+    transformations.foreach(transformation => getTransformation(printWriter, transformation, 0))
+    stringWriter.toString
+  }
+
+  private def getTransformation(
+      printWriter: PrintWriter,
+      transformation: Transformation[_],
+      level: Int): Unit = {
+    if (level == 0) {
+      printWriter.println(transformation.toStringWithoutId)
+    } else {
+      printWriter.println(("\t" * level) + "+- " + 
transformation.toStringWithoutId)
+    }
+    transformation.getInputs.foreach(child => getTransformation(printWriter, 
child, level + 1))
+  }
+
   /** Replace the estimated costs for the given plan, because it may be unstable. */
   protected def replaceEstimatedCost(s: String): String = {
     var str = s.replaceAll("\\r\\n", "\n")
@@ -1624,6 +1670,9 @@ object PlanKind extends Enumeration {
 
   /** Optimized Execution Plan */
   val OPT_EXEC: Value = Value("OPT_EXEC")
+
+  /** Transformation */
+  val TRANSFORM: Value = Value("TRANSFORM")
 }
 
 object TableTestUtil {

