This is an automated email from the ASF dual-hosted git repository. libenchao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 42b7e74ab20 [FLINK-33263][table-planner] Implement ParallelismProvider for sources in the table planner 42b7e74ab20 is described below commit 42b7e74ab20785289b62f5dd68d566995ba9dcfc Author: SuDewei <sudewei....@outlook.com> AuthorDate: Thu Jan 18 16:05:40 2024 +0800 [FLINK-33263][table-planner] Implement ParallelismProvider for sources in the table planner Close apache/flink#24128 --- .../org/apache/flink/api/dag/Transformation.java | 15 ++ .../streaming/api/graph/StreamGraphGenerator.java | 3 + .../SourceTransformationWrapper.java | 72 ++++++++++ .../exec/common/CommonExecTableSourceScan.java | 154 ++++++++++++++++++--- .../table/planner/delegation/BatchPlanner.scala | 2 +- .../table/planner/delegation/PlannerBase.scala | 3 +- .../table/planner/delegation/StreamPlanner.scala | 2 +- .../planner/factories/TestValuesTableFactory.java | 33 +++-- .../planner/plan/stream/sql/TableScanTest.xml | 42 ++++++ .../planner/plan/stream/sql/TableScanTest.scala | 38 +++++ .../runtime/stream/sql/TableSourceITCase.scala | 80 +++++++++++ .../flink/table/planner/utils/TableTestBase.scala | 51 ++++++- 12 files changed, 463 insertions(+), 32 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java b/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java index a0448697dd1..6256f9624f6 100644 --- a/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java +++ b/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java @@ -19,6 +19,7 @@ package org.apache.flink.api.dag; import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.InvalidTypesException; import org.apache.flink.api.common.operators.ResourceSpec; @@ -602,6 +603,20 @@ public abstract class Transformation<T> { + '}'; } + @VisibleForTesting + public String toStringWithoutId() { + return getClass().getSimpleName() + + "{" + + "name='" + + name + + '\'' + + ", outputType=" + + outputType + + ", parallelism=" + + parallelism + + '}'; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java index 5929a2a5e8e..8e267ff84d6 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java @@ -65,6 +65,7 @@ import org.apache.flink.streaming.api.transformations.ReduceTransformation; import org.apache.flink.streaming.api.transformations.SideOutputTransformation; import org.apache.flink.streaming.api.transformations.SinkTransformation; import org.apache.flink.streaming.api.transformations.SourceTransformation; +import org.apache.flink.streaming.api.transformations.SourceTransformationWrapper; import org.apache.flink.streaming.api.transformations.TimestampsAndWatermarksTransformation; import org.apache.flink.streaming.api.transformations.TwoInputTransformation; import org.apache.flink.streaming.api.transformations.UnionTransformation; @@ -553,6 +554,8 @@ public class StreamGraphGenerator { transformedIds = transformFeedback((FeedbackTransformation<?>) transform); } else if (transform instanceof 
CoFeedbackTransformation<?>) { transformedIds = transformCoFeedback((CoFeedbackTransformation<?>) transform); + } else if (transform instanceof SourceTransformationWrapper<?>) { + transformedIds = transform(((SourceTransformationWrapper<?>) transform).getInput()); + } else { throw new IllegalStateException("Unknown transformation: " + transform); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformationWrapper.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformationWrapper.java new file mode 100644 index 00000000000..d536000fde2 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformationWrapper.java @@ -0,0 +1,72 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package org.apache.flink.streaming.api.transformations; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.streaming.api.graph.TransformationTranslator; + +import org.apache.flink.shaded.guava31.com.google.common.collect.Lists; + +import java.util.Collections; +import java.util.List; + +/** + * This Transformation is a phantom transformation that is used to expose a default parallelism to + * downstream operations. + * + * <p>It is used only when the parallelism of the source transformation differs from the default + * parallelism, ensuring that the parallelism of downstream operations is not affected. + * + * <p>Moreover, this transformation does not have a corresponding {@link TransformationTranslator}, + * meaning it will not become a node in the StreamGraph. 
+ * + * @param <T> The type of the elements in the input {@code Transformation} + */ +@Internal +public class SourceTransformationWrapper<T> extends Transformation<T> { + + private final Transformation<T> input; + + public SourceTransformationWrapper(Transformation<T> input) { + super( + "ChangeToDefaultParallel", + input.getOutputType(), + ExecutionConfig.PARALLELISM_DEFAULT); + this.input = input; + } + + public Transformation<T> getInput() { + return input; + } + + @Override + public List<Transformation<?>> getTransitivePredecessors() { + List<Transformation<?>> result = Lists.newArrayList(); + result.add(this); + result.addAll(input.getTransitivePredecessors()); + return result; + } + + @Override + public List<Transformation<?>> getInputs() { + return Collections.singletonList(input); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java index dc69543cd28..be5b46ba973 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.plan.nodes.exec.common; +import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.eventtime.WatermarkStrategy; import org.apache.flink.api.common.io.InputFormat; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -30,6 +31,13 @@ import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.operators.StreamSource; import org.apache.flink.streaming.api.transformations.LegacySourceTransformation; +import org.apache.flink.streaming.api.transformations.PartitionTransformation; +import org.apache.flink.streaming.api.transformations.SourceTransformationWrapper; +import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.connector.ChangelogMode; +import org.apache.flink.table.connector.ParallelismProvider; import org.apache.flink.table.connector.ProviderContext; import org.apache.flink.table.connector.source.DataStreamScanProvider; import org.apache.flink.table.connector.source.InputFormatProvider; @@ -48,17 +56,22 @@ import org.apache.flink.table.planner.plan.nodes.exec.MultipleTransformationTran import org.apache.flink.table.planner.plan.nodes.exec.spec.DynamicTableSourceSpec; import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecNode; import org.apache.flink.table.planner.plan.nodes.exec.utils.TransformationMetadata; +import org.apache.flink.table.planner.plan.utils.KeySelectorUtil; import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.flink.table.runtime.connector.source.ScanRuntimeProviderContext; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.types.RowKind; import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; import java.util.Optional; +import static org.apache.flink.runtime.state.KeyGroupRangeAssignment.DEFAULT_LOWER_BOUND_MAX_PARALLELISM; + /** * Base {@link ExecNode} to read data from an external source defined by a {@link ScanTableSource}. */ @@ -96,6 +109,7 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData> @Override protected Transformation<RowData> translateToPlanInternal( PlannerBase planner, ExecNodeConfig config) { + final Transformation<RowData> sourceTransform; final StreamExecutionEnvironment env = planner.getExecEnv(); final TransformationMetadata meta = createTransformationMeta(SOURCE_TRANSFORMATION, config); final InternalTypeInfo<RowData> outputTypeInfo = @@ -105,54 +119,149 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData> planner.getFlinkContext(), ShortcutUtils.unwrapTypeFactory(planner)); ScanTableSource.ScanRuntimeProvider provider = tableSource.getScanRuntimeProvider(ScanRuntimeProviderContext.INSTANCE); + final int sourceParallelism = deriveSourceParallelism(provider); + final boolean sourceParallelismConfigured = isParallelismConfigured(provider); if (provider instanceof SourceFunctionProvider) { final SourceFunctionProvider sourceFunctionProvider = (SourceFunctionProvider) provider; final SourceFunction<RowData> function = sourceFunctionProvider.createSourceFunction(); - final Transformation<RowData> transformation = + sourceTransform = createSourceFunctionTransformation( env, function, sourceFunctionProvider.isBounded(), meta.getName(), - outputTypeInfo); - return meta.fill(transformation); + outputTypeInfo, + sourceParallelism, + sourceParallelismConfigured); + if (function instanceof ParallelSourceFunction && sourceParallelismConfigured) { + meta.fill(sourceTransform); + return new SourceTransformationWrapper<>(sourceTransform); + } else { + return meta.fill(sourceTransform); + } } else if (provider instanceof InputFormatProvider) { final InputFormat<RowData, ?> inputFormat = ((InputFormatProvider) provider).createInputFormat(); - final Transformation<RowData> transformation = + sourceTransform = createInputFormatTransformation( env, inputFormat, outputTypeInfo, meta.getName()); - return meta.fill(transformation); + meta.fill(sourceTransform); } else if (provider instanceof SourceProvider) { final Source<RowData, ?, ?> source = ((SourceProvider) provider).createSource(); // TODO: Push down watermark strategy to source scan - final Transformation<RowData> transformation = + sourceTransform = env.fromSource( source, WatermarkStrategy.noWatermarks(), meta.getName(), outputTypeInfo) .getTransformation(); - return meta.fill(transformation); + meta.fill(sourceTransform); } else if (provider instanceof DataStreamScanProvider) { - Transformation<RowData> transformation = + sourceTransform = ((DataStreamScanProvider) provider) .produceDataStream(createProviderContext(config), env) .getTransformation(); - meta.fill(transformation); - transformation.setOutputType(outputTypeInfo); - return transformation; + meta.fill(sourceTransform); + sourceTransform.setOutputType(outputTypeInfo); } else if (provider instanceof TransformationScanProvider) { - final Transformation<RowData> transformation = + sourceTransform = ((TransformationScanProvider) provider) .createTransformation(createProviderContext(config)); - meta.fill(transformation); - transformation.setOutputType(outputTypeInfo); - return transformation; + 
meta.fill(sourceTransform); + sourceTransform.setOutputType(outputTypeInfo); } else { throw new UnsupportedOperationException( provider.getClass().getSimpleName() + " is unsupported now."); } + + if (sourceParallelismConfigured) { return applySourceTransformationWrapper( sourceTransform, planner.getFlinkContext().getClassLoader(), outputTypeInfo, config, tableSource.getChangelogMode(), sourceParallelism); } else { return sourceTransform; } + } + + private boolean isParallelismConfigured(ScanTableSource.ScanRuntimeProvider runtimeProvider) { + return runtimeProvider instanceof ParallelismProvider + && ((ParallelismProvider) runtimeProvider).getParallelism().isPresent(); + } + + private int deriveSourceParallelism(ScanTableSource.ScanRuntimeProvider runtimeProvider) { + if (isParallelismConfigured(runtimeProvider)) { + int sourceParallelism = ((ParallelismProvider) runtimeProvider).getParallelism().get(); + if (sourceParallelism <= 0) { + throw new TableException( + String.format( + "Invalid configured parallelism %s for table '%s'.", + sourceParallelism, + tableSourceSpec + .getContextResolvedTable() + .getIdentifier() + .asSummaryString())); + } + return sourceParallelism; + } else { + return ExecutionConfig.PARALLELISM_DEFAULT; + } + } + + protected RowType getPhysicalRowType(ResolvedSchema schema) { + return (RowType) schema.toPhysicalRowDataType().getLogicalType(); + } + + protected int[] getPrimaryKeyIndices(RowType sourceRowType, ResolvedSchema schema) { + return schema.getPrimaryKey() + .map(k -> k.getColumns().stream().mapToInt(sourceRowType::getFieldIndex).toArray()) + .orElse(new int[0]); + } + + private Transformation<RowData> applySourceTransformationWrapper( + Transformation<RowData> sourceTransform, + ClassLoader classLoader, + InternalTypeInfo<RowData> outputTypeInfo, + ExecNodeConfig config, + ChangelogMode changelogMode, + int sourceParallelism) { + sourceTransform.setParallelism(sourceParallelism, true); + Transformation<RowData> sourceTransformationWrapper = + new SourceTransformationWrapper<>(sourceTransform); + + if (!changelogMode.containsOnly(RowKind.INSERT)) { + final ResolvedSchema schema = + tableSourceSpec.getContextResolvedTable().getResolvedSchema(); + final RowType physicalRowType = getPhysicalRowType(schema); + final int[] primaryKeys = getPrimaryKeyIndices(physicalRowType, schema); + final boolean hasPk = primaryKeys.length > 0; + if (!hasPk) { + throw new TableException( + String.format( + "Configured parallelism %s for upsert table '%s' but cannot find a primary key field. 
" + + "This is a bug, please file an issue.", + sourceParallelism, + tableSourceSpec + .getContextResolvedTable() + .getIdentifier() + .asSummaryString())); + } + final RowDataKeySelector selector = + KeySelectorUtil.getRowDataSelector(classLoader, primaryKeys, outputTypeInfo); + final KeyGroupStreamPartitioner<RowData, RowData> partitioner = + new KeyGroupStreamPartitioner<>(selector, DEFAULT_LOWER_BOUND_MAX_PARALLELISM); + Transformation<RowData> partitionedTransform = + new PartitionTransformation<>(sourceTransformationWrapper, partitioner); + createTransformationMeta("partitioner", "Partitioner", "Partitioner", config) + .fill(partitionedTransform); + return partitionedTransform; + } else { + return sourceTransformationWrapper; + } } private ProviderContext createProviderContext(ExecNodeConfig config) { @@ -178,17 +287,22 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData> SourceFunction<RowData> function, boolean isBounded, String operatorName, - TypeInformation<RowData> outputTypeInfo) { + TypeInformation<RowData> outputTypeInfo, + int sourceParallelism, + boolean sourceParallelismConfigured) { env.clean(function); final int parallelism; - boolean parallelismConfigured = false; if (function instanceof ParallelSourceFunction) { - parallelism = env.getParallelism(); + if (sourceParallelismConfigured) { + parallelism = sourceParallelism; + } else { + parallelism = env.getParallelism(); + } } else { parallelism = 1; - parallelismConfigured = true; + sourceParallelismConfigured = true; } final Boundedness boundedness; @@ -205,7 +319,7 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData> outputTypeInfo, parallelism, boundedness, - parallelismConfigured); + sourceParallelismConfigured); } /** diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala index bb4c1b75a28..cea10f7bb81 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala @@ -84,7 +84,7 @@ class BatchPlanner( processors } - override protected def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = { + override def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = { beforeTranslation() val planner = createDummyPlanner() diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala index 45788e6278e..b36edaa21d7 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala @@ -367,7 +367,8 @@ abstract class PlannerBase( * @return * The [[Transformation]] DAG that corresponds to the node DAG. 
*/ - protected def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] + @VisibleForTesting + def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] def addExtraTransformation(transformation: Transformation[_]): Unit = { if (!extraTransformations.contains(transformation)) { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala index fb32326f117..894a37c8cf9 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala @@ -78,7 +78,7 @@ class StreamPlanner( override protected def getExecNodeGraphProcessors: Seq[ExecNodeGraphProcessor] = Seq() - override protected def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = { + override def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = { beforeTranslation() val planner = createDummyPlanner() val transformations = execGraph.getRootNodes.map { diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java index 3dbf4d5b9c0..db64847b75b 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java @@ -466,6 +466,8 @@ public final class TestValuesTableFactory private static final ConfigOption<String> SINK_CHANGELOG_MODE_ENFORCED = ConfigOptions.key("sink-changelog-mode-enforced").stringType().noDefaultValue(); + private static final ConfigOption<Integer> SOURCE_PARALLELISM = FactoryUtil.SOURCE_PARALLELISM; + private static final ConfigOption<Integer> SINK_PARALLELISM = FactoryUtil.SINK_PARALLELISM; @Override @@ -497,6 +499,7 @@ public final class TestValuesTableFactory int lookupThreshold = helper.getOptions().get(LOOKUP_THRESHOLD); int sleepAfterElements = helper.getOptions().get(SOURCE_SLEEP_AFTER_ELEMENTS); long sleepTimeMillis = helper.getOptions().get(SOURCE_SLEEP_TIME).toMillis(); + Integer parallelism = helper.getOptions().get(SOURCE_PARALLELISM); DefaultLookupCache cache = null; if (helper.getOptions().get(CACHE_TYPE).equals(LookupOptions.LookupCacheType.PARTIAL)) { cache = DefaultLookupCache.fromConfig(helper.getOptions()); @@ -571,7 +574,8 @@ public final class TestValuesTableFactory Long.MAX_VALUE, partitions, readableMetadata, - null); + null, + parallelism); } if (disableLookup) { @@ -746,6 +750,7 @@ public final class TestValuesTableFactory SOURCE_NUM_ELEMENT_TO_SKIP, SOURCE_SLEEP_AFTER_ELEMENTS, SOURCE_SLEEP_TIME, + SOURCE_PARALLELISM, INTERNAL_DATA, CACHE_TYPE, PARTIAL_CACHE_EXPIRE_AFTER_ACCESS, @@ -916,6 +921,7 @@ public final class TestValuesTableFactory private @Nullable int[] groupingSet; private List<AggregateExpression> aggregateExpressions; private List<String> acceptedPartitionFilterFields; + private final Integer parallelism; private TestValuesScanTableSourceWithoutProjectionPushDown( DataType producedDataType, @@ -934,7 +940,8 @@ public final class TestValuesTableFactory long limit, List<Map<String, String>> allPartitions, 
Map<String, DataType> readableMetadata, - @Nullable int[] projectedMetadataFields) { + @Nullable int[] projectedMetadataFields, + @Nullable Integer parallelism) { this.producedDataType = producedDataType; this.changelogMode = changelogMode; this.boundedness = boundedness; @@ -954,6 +961,7 @@ public final class TestValuesTableFactory this.projectedMetadataFields = projectedMetadataFields; this.groupingSet = null; this.aggregateExpressions = Collections.emptyList(); + this.parallelism = parallelism; } @Override @@ -987,7 +995,7 @@ public final class TestValuesTableFactory sourceFunction = new FromElementsFunction<>(serializer, values); } return SourceFunctionProvider.of( - sourceFunction, boundedness == Boundedness.BOUNDED); + sourceFunction, boundedness == Boundedness.BOUNDED, parallelism); } catch (IOException e) { throw new TableException("Fail to init source function", e); } @@ -999,7 +1007,8 @@ public final class TestValuesTableFactory terminating == TerminatingLogic.FINITE, "Values Source doesn't support infinite InputFormat."); Collection<RowData> values = convertToRowData(converter); - return InputFormatProvider.of(new CollectionInputFormat<>(values, serializer)); + return InputFormatProvider.of( + new CollectionInputFormat<>(values, serializer), parallelism); case "DataStream": checkArgument( !failingSource, @@ -1024,6 +1033,11 @@ public final class TestValuesTableFactory return sourceStream; } + @Override + public Optional<Integer> getParallelism() { + return Optional.ofNullable(parallelism); + } + @Override public boolean isBounded() { return boundedness == Boundedness.BOUNDED; @@ -1039,7 +1053,8 @@ public final class TestValuesTableFactory || acceptedPartitionFilterFields.isEmpty()) { Collection<RowData> values2 = convertToRowData(converter); return SourceProvider.of( - new ValuesSource(terminating, boundedness, values2, serializer)); + new ValuesSource(terminating, boundedness, values2, serializer), + parallelism); } else { Map<Map<String, String>, Collection<RowData>> partitionValues = convertToPartitionedRowData(converter); @@ -1050,7 +1065,7 @@ public final class TestValuesTableFactory partitionValues, serializer, acceptedPartitionFilterFields); - return SourceProvider.of(source); + return SourceProvider.of(source, parallelism); } default: throw new IllegalArgumentException( @@ -1114,7 +1129,8 @@ public final class TestValuesTableFactory limit, allPartitions, readableMetadata, - projectedMetadataFields); + projectedMetadataFields, + parallelism); } @Override @@ -1477,7 +1493,8 @@ public final class TestValuesTableFactory limit, allPartitions, readableMetadata, - projectedMetadataFields); + projectedMetadataFields, + null); } @Override diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml index 52d21087262..8fe6835213c 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml @@ -727,4 +727,46 @@ Calc(select=[ts, a, b], where=[>(a, 1)], changelogMode=[I,UB,UA,D]) ]]> </Resource> </TestCase> + + <TestCase name="testSetParallelismForSource"> + <Resource name="sql"> + <![CDATA[SELECT * FROM src LEFT JOIN changelog_src on src.id = changelog_src.id WHERE src.c > 1]]> + </Resource> + <Resource 
name="ast"> + <![CDATA[ +LogicalProject(id=[$0], b=[$1], c=[$2], id0=[$3], a=[$4]) ++- LogicalFilter(condition=[>($2, 1)]) + +- LogicalJoin(condition=[=($0, $3)], joinType=[left]) + :- LogicalTableScan(table=[[default_catalog, default_database, src]]) + +- LogicalTableScan(table=[[default_catalog, default_database, changelog_src]]) +]]> + </Resource> + <Resource name="optimized exec plan"> + <![CDATA[ +Join(joinType=[LeftOuterJoin], where=[(id = id0)], select=[id, b, c, id0, a], leftInputSpec=[NoUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey]) +:- Exchange(distribution=[hash[id]]) +: +- Calc(select=[id, b, c], where=[(c > 1)]) +: +- TableSourceScan(table=[[default_catalog, default_database, src, filter=[]]], fields=[id, b, c]) ++- Exchange(distribution=[hash[id]]) + +- ChangelogNormalize(key=[id]) + +- Exchange(distribution=[hash[id]]) + +- TableSourceScan(table=[[default_catalog, default_database, changelog_src]], fields=[id, a]) +]]> + </Resource> + <Resource name="transformation"> + <![CDATA[ +TwoInputTransformation{name='Join(joinType=[LeftOuterJoin], where=[(id = id0)], select=[id, b, c, id0, a], leftInputSpec=[NoUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey])', outputType=ROW<`id` INT, `b` STRING, `c` INT, `id0` INT, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1} + +- PartitionTransformation{name='Exchange(distribution=[hash[id]])', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1} + +- OneInputTransformation{name='Calc(select=[id, b, c], where=[(c > 1)])', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1} + +- SourceTransformationWrapper{name='ChangeToDefaultParallel', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1} + +- LegacySourceTransformation{name='TableSourceScan(table=[[default_catalog, default_database, src, filter=[]]], fields=[id, b, c])', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=3} + +- PartitionTransformation{name='Exchange(distribution=[hash[id]])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1} + +- OneInputTransformation{name='ChangelogNormalize(key=[id])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1} + +- PartitionTransformation{name='Exchange(distribution=[hash[id]])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1} + +- PartitionTransformation{name='Partitioner', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1} + +- SourceTransformationWrapper{name='ChangeToDefaultParallel', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1} + +- 
LegacySourceTransformation{name='TableSourceScan(table=[[default_catalog, default_database, changelog_src]], fields=[id, a])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=5} +]]> + </Resource> + </TestCase> </Root> diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala index 0a31589b61c..be1ae70d3fa 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala @@ -775,4 +775,42 @@ class TableScanTest extends TableTestBase { "expression type is CHAR(0) NOT NULL") .isInstanceOf[ValidationException] } + + @Test + def testSetParallelismForSource(): Unit = { + val config = TableConfig.getDefault + config.set(ExecutionConfigOptions.TABLE_EXEC_SIMPLIFY_OPERATOR_NAME_ENABLED, Boolean.box(false)) + val util = streamTestUtil(config) + + util.addTable(""" + |CREATE TABLE changelog_src ( + | id INT, + | a STRING, + | PRIMARY KEY (id) NOT ENFORCED + |) WITH ( + | 'connector' = 'values', + | 'bounded' = 'true', + | 'runtime-source' = 'DataStream', + | 'scan.parallelism' = '5', + | 'enable-projection-push-down' = 'false', + | 'changelog-mode' = 'I,UA,D' + |) + """.stripMargin) + util.addTable(""" + |CREATE TABLE src ( + | id INT, + | b STRING, + | c INT + |) WITH ( + | 'connector' = 'values', + | 'bounded' = 'true', + | 'runtime-source' = 'DataStream', + | 'scan.parallelism' = '3', + | 'enable-projection-push-down' = 'false' + |) + """.stripMargin) + util.verifyTransformation( + "SELECT * FROM src LEFT JOIN changelog_src " + + "on src.id = changelog_src.id WHERE src.c > 1") + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala index 26b5d3a1709..a2089ee404e 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala @@ -24,6 +24,7 @@ import org.apache.flink.table.api.{DataTypes, TableException} import org.apache.flink.table.api.bridge.scala._ import org.apache.flink.table.planner.factories.TestValuesTableFactory import org.apache.flink.table.planner.runtime.utils.{StreamingTestBase, TestData, TestingAppendSink, TestingRetractSink} +import org.apache.flink.table.planner.runtime.utils.TestData.data1 import org.apache.flink.table.planner.utils._ import org.apache.flink.table.runtime.functions.scalar.SourceWatermarkFunction import org.apache.flink.table.utils.LegacyRowExtension @@ -33,6 +34,8 @@ import org.assertj.core.api.Assertions.{assertThat, assertThatThrownBy} import org.junit.jupiter.api.{BeforeEach, Test} import org.junit.jupiter.api.extension.RegisterExtension +import java.util.concurrent.atomic.AtomicInteger + class TableSourceITCase extends StreamingTestBase { @RegisterExtension private val _: EachCallbackWrapper[LegacyRowExtension] = @@ -421,4 +424,81 @@ class TableSourceITCase extends StreamingTestBase { val expected 
= Seq("1,Sarah,1", "2,Rob,1", "3,Mike,1") assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted) } + + private def innerTestSetParallelism(provider: String, parallelism: Int, index: Int): Unit = { + val dataId = TestValuesTableFactory.registerData(data1) + val sourceTableName = s"test_para_source_${provider.toLowerCase.trim}_$index" + val sinkTableName = s"test_para_sink_${provider.toLowerCase.trim}_$index" + tEnv.executeSql(s""" + |CREATE TABLE $sourceTableName ( + | the_month INT, + | area STRING, + | product INT + |) WITH ( + | 'connector' = 'values', + | 'data-id' = '$dataId', + | 'bounded' = 'true', + | 'runtime-source' = '$provider', + | 'scan.parallelism' = '$parallelism', + | 'enable-projection-push-down' = 'false' + |) + |""".stripMargin) + tEnv.executeSql(s""" + |CREATE TABLE $sinkTableName ( + | the_month INT, + | area STRING, + | product INT + |) WITH ( + | 'connector' = 'values', + | 'sink-insert-only' = 'true' + |) + |""".stripMargin) + tEnv.executeSql(s"INSERT INTO $sinkTableName SELECT * FROM $sourceTableName").await() + } + + @Test + def testParallelismWithSourceFunction(): Unit = { + val negativeParallelism = -1 + val validParallelism = 3 + val index = new AtomicInteger(1) + + assertThatThrownBy( + () => + innerTestSetParallelism( + "SourceFunction", + negativeParallelism, + index = index.getAndIncrement)) + .hasMessageContaining(s"Invalid configured parallelism") + + innerTestSetParallelism("SourceFunction", validParallelism, index = index.getAndIncrement) + } + + @Test + def testParallelismWithInputFormat(): Unit = { + val negativeParallelism = -1 + val validParallelism = 3 + val index = new AtomicInteger(2) + + assertThatThrownBy( + () => + innerTestSetParallelism("InputFormat", negativeParallelism, index = index.getAndIncrement)) + .hasMessageContaining(s"Invalid configured parallelism") + + innerTestSetParallelism("InputFormat", validParallelism, index = index.getAndIncrement) + } + + @Test + def testParallelismWithDataStream(): Unit = { + val negativeParallelism = -1 + val validParallelism = 3 + val index = new AtomicInteger(3) + + assertThatThrownBy( + () => + innerTestSetParallelism("DataStream", negativeParallelism, index = index.getAndIncrement)) + .hasMessageContaining(s"Invalid configured parallelism") + + innerTestSetParallelism("DataStream", validParallelism, index = index.getAndIncrement) + } + } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala index 1e006f3d94b..e5b418365be 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala @@ -19,6 +19,7 @@ package org.apache.flink.table.planner.utils import org.apache.flink.FlinkVersion import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation} +import org.apache.flink.api.dag.Transformation import org.apache.flink.api.java.typeutils.{PojoTypeInfo, RowTypeInfo, TupleTypeInfo} import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo import org.apache.flink.configuration.BatchExecutionOptions @@ -87,7 +88,7 @@ import org.junit.jupiter.api.extension.{BeforeEachCallback, ExtendWith, Extensio import org.junit.jupiter.api.io.TempDir import org.junit.platform.commons.support.AnnotationSupport -import java.io.{File, IOException} +import java.io.{File, 
IOException, PrintWriter, StringWriter} import java.net.URL import java.nio.file.{Files, Path, Paths} import java.time.Duration @@ -702,6 +703,20 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean) withQueryBlockAlias = false) } + /** + * Verify the AST (abstract syntax tree), the optimized exec plan and the transformation graph + * for the given SELECT query. Note: an exception will be thrown if the given SQL can't be + * translated to an exec plan or if the transformation result is wrong. + */ + def verifyTransformation(query: String): Unit = { + doVerifyPlan( + query, + Array.empty[ExplainDetail], + withRowType = false, + Array(PlanKind.AST, PlanKind.OPT_EXEC, PlanKind.TRANSFORM), + withQueryBlockAlias = false) + } + /** Verify the explain result for the given SELECT query. See more about [[Table#explain()]]. */ def verifyExplain(query: String): Unit = verifyExplain(getTableEnv.sqlQuery(query)) @@ -1040,6 +1055,14 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean) "" } + // build transformation graph if `expectedPlans` contains TRANSFORM + val transformation = if (expectedPlans.contains(PlanKind.TRANSFORM)) { + val optimizedNodes = getPlanner.translateToExecNodeGraph(optimizedRels, true) + System.lineSeparator + getTransformations(getPlanner.translateToPlan(optimizedNodes)) + } else { + "" + } + // check whether the sql equals to the expected if the `relNodes` are translated from sql assertSqlEqualsOrExpandFunc() // check ast plan @@ -1058,6 +1081,10 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean) if (expectedPlans.contains(PlanKind.OPT_EXEC)) { assertEqualsOrExpand("optimized exec plan", optimizedExecPlan, expand = false) } + // check transformation graph + if (expectedPlans.contains(PlanKind.TRANSFORM)) { + assertEqualsOrExpand("transformation", transformation, expand = false) + } } private def doVerifyExplain(explainResult: String, extraDetails: ExplainDetail*): Unit = { @@ -1117,6 +1144,25 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean) replaceEstimatedCost(optimizedPlan) } + private def getTransformations(transformations: java.util.List[Transformation[_]]): String = { + val stringWriter = new StringWriter() + val printWriter = new PrintWriter(stringWriter) + transformations.foreach(transformation => getTransformation(printWriter, transformation, 0)) + stringWriter.toString + } + + private def getTransformation( + printWriter: PrintWriter, + transformation: Transformation[_], + level: Int): Unit = { + if (level == 0) { + printWriter.println(transformation.toStringWithoutId) + } else { + printWriter.println(("\t" * level) + "+- " + transformation.toStringWithoutId) + } + transformation.getInputs.foreach(child => getTransformation(printWriter, child, level + 1)) + } + /** Replace the estimated costs for the given plan, because it may be unstable. */ protected def replaceEstimatedCost(s: String): String = { var str = s.replaceAll("\\r\\n", "\n") @@ -1624,6 +1670,9 @@ object PlanKind extends Enumeration { /** Optimized Execution Plan */ val OPT_EXEC: Value = Value("OPT_EXEC") + + /** Transformation */ + val TRANSFORM: Value = Value("TRANSFORM") } object TableTestUtil {
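
For connector authors, the planner changes above take effect when a scan's ScanRuntimeProvider also implements ParallelismProvider and reports a parallelism. The following minimal sketch shows the connector side; the class name SketchScanTableSource and its constructor wiring are hypothetical and not part of this commit, while the SourceProvider.of(source, parallelism) overload and the 'scan.parallelism' (FactoryUtil.SOURCE_PARALLELISM) option are the ones used by TestValuesTableFactory in this diff:

    import javax.annotation.Nullable;

    import org.apache.flink.api.connector.source.Source;
    import org.apache.flink.table.connector.ChangelogMode;
    import org.apache.flink.table.connector.source.DynamicTableSource;
    import org.apache.flink.table.connector.source.ScanTableSource;
    import org.apache.flink.table.connector.source.SourceProvider;
    import org.apache.flink.table.data.RowData;

    /** Hypothetical connector sketch: exposes a configured source parallelism to the planner. */
    public class SketchScanTableSource implements ScanTableSource {

        private final Source<RowData, ?, ?> source;   // any FLIP-27 source producing RowData
        private final @Nullable Integer parallelism;  // e.g. parsed from 'scan.parallelism'; null = default

        public SketchScanTableSource(Source<RowData, ?, ?> source, @Nullable Integer parallelism) {
            this.source = source;
            this.parallelism = parallelism;
        }

        @Override
        public ChangelogMode getChangelogMode() {
            // Insert-only: the planner only sets the source parallelism and wraps the
            // transformation. For changelog sources it additionally hash-partitions by the
            // primary key (see applySourceTransformationWrapper above).
            return ChangelogMode.insertOnly();
        }

        @Override
        public ScanRuntimeProvider getScanRuntimeProvider(ScanContext context) {
            // SourceProvider implements ParallelismProvider; a non-null value here is what
            // makes isParallelismConfigured(...) in CommonExecTableSourceScan return true.
            return SourceProvider.of(source, parallelism);
        }

        @Override
        public DynamicTableSource copy() {
            return new SketchScanTableSource(source, parallelism);
        }

        @Override
        public String asSummaryString() {
            return "SketchScanTableSource";
        }
    }

With a configured value of, say, 5, the planner sets it on the source transformation via setParallelism(sourceParallelism, true) and inserts the SourceTransformationWrapper named "ChangeToDefaultParallel" so that downstream operators keep the default parallelism, as the TableScanTest.xml expectation above shows; a value <= 0 is rejected with the "Invalid configured parallelism" TableException.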