This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit dbcd2d7b86fcb7fa7a26e181f1719ea4c6dad828 Author: lincoln.lil <lincoln.8...@gmail.com> AuthorDate: Thu Sep 8 18:08:58 2022 +0800 [FLINK-28569][table-planner] Add projectRowType to RowTypeUtils and deprecate AggCodeGenHelper#projectRowType This closes #20791 --- .../table/planner/typeutils/RowTypeUtils.java | 35 ++++++++++++++++ .../codegen/agg/batch/AggCodeGenHelper.scala | 4 -- .../codegen/agg/batch/HashAggCodeGenerator.scala | 3 +- .../codegen/agg/batch/SortAggCodeGenerator.scala | 3 +- .../codegen/agg/batch/WindowCodeGenerator.scala | 3 +- .../table/planner/typeutils/RowTypeUtilsTest.java | 46 ++++++++++++++++++++++ 6 files changed, 87 insertions(+), 7 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java index ffb9a68a131..4d9879d7b99 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/typeutils/RowTypeUtils.java @@ -18,7 +18,13 @@ package org.apache.flink.table.planner.typeutils; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; + +import javax.annotation.Nonnull; + import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -46,4 +52,33 @@ public class RowTypeUtils { } return result; } + + /** + * Returns projected {@link RowType} by given projection indexes over original {@link RowType}. + * Will raise an error when projection index beyond the field count of original rowType. + * + * @param rowType source row type + * @param projection indexes array + * @return projected {@link RowType} + */ + public static RowType projectRowType(@Nonnull RowType rowType, @Nonnull int[] projection) + throws IllegalArgumentException { + final int fieldCnt = rowType.getFieldCount(); + return RowType.of( + Arrays.stream(projection) + .mapToObj( + index -> { + if (index >= fieldCnt) { + throw new IllegalArgumentException( + String.format( + "Invalid projection index: %d of source rowType size: %d", + index, fieldCnt)); + } + return rowType.getTypeAt(index); + }) + .toArray(LogicalType[]::new), + Arrays.stream(projection) + .mapToObj(index -> rowType.getFieldNames().get(index)) + .toArray(String[]::new)); + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala index bb1135fc236..c401e50ea24 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala @@ -93,10 +93,6 @@ object AggCodeGenHelper { .asInstanceOf[Map[AggregateFunction[_, _], String]] } - def projectRowType(rowType: RowType, mapping: Array[Int]): RowType = { - RowType.of(mapping.map(rowType.getTypeAt), mapping.map(rowType.getFieldNames.get(_))) - } - /** Add agg handler to class member and open it. */ private[flink] def addAggsHandler( aggsHandler: GeneratedAggsHandleFunction, diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala index 550a93df3fb..c768ffd5f70 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala @@ -25,6 +25,7 @@ import org.apache.flink.table.functions.AggregateFunction import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, ProjectionCodeGenerator} import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList} +import org.apache.flink.table.planner.typeutils.RowTypeUtils import org.apache.flink.table.runtime.generated.GeneratedOperator import org.apache.flink.table.runtime.operators.TableStreamOperator import org.apache.flink.table.runtime.operators.aggregate.BytesHashMapSpillMemorySegmentPool @@ -60,7 +61,7 @@ class HashAggCodeGenerator( private lazy val aggBufferTypes: Array[Array[LogicalType]] = AggCodeGenHelper.getAggBufferTypes(inputType, auxGrouping, aggInfos) - private lazy val groupKeyRowType = AggCodeGenHelper.projectRowType(inputType, grouping) + private lazy val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping) private lazy val aggBufferRowType = RowType.of(aggBufferTypes.flatten, aggBufferNames.flatten) def genWithKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala index 3a183dc4183..02d44d733ed 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala @@ -25,6 +25,7 @@ import org.apache.flink.table.functions.AggregateFunction import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, ProjectionCodeGenerator} import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.generateCollect import org.apache.flink.table.planner.plan.utils.AggregateInfoList +import org.apache.flink.table.planner.typeutils.RowTypeUtils import org.apache.flink.table.runtime.generated.GeneratedOperator import org.apache.flink.table.runtime.operators.TableStreamOperator import org.apache.flink.table.types.logical.RowType @@ -63,7 +64,7 @@ object SortAggCodeGenerator { val currentKeyTerm = "currentKey" val currentKeyWriterTerm = "currentKeyWriter" - val groupKeyRowType = AggCodeGenHelper.projectRowType(inputType, grouping) + val groupKeyRowType = RowTypeUtils.projectRowType(inputType, grouping) val keyProjectionCode = ProjectionCodeGenerator .generateProjectionExpression( ctx, diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala index ab4a4221cbb..2df6ae64b14 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala @@ -38,6 +38,7 @@ import org.apache.flink.table.planner.expressions.ExpressionBuilder._ import org.apache.flink.table.planner.expressions.converter.ExpressionConverter import org.apache.flink.table.planner.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow} import org.apache.flink.table.planner.plan.utils.{AggregateInfo, AggregateInfoList, AggregateUtil} +import org.apache.flink.table.planner.typeutils.RowTypeUtils import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTypeFactory import org.apache.flink.table.runtime.groupwindow.NamedWindowProperty import org.apache.flink.table.runtime.operators.window.TimeWindow @@ -82,7 +83,7 @@ abstract class WindowCodeGenerator( AggCodeGenHelper.getAggBufferTypes(inputRowType, auxGrouping, aggInfos) protected lazy val groupKeyRowType: RowType = - AggCodeGenHelper.projectRowType(inputRowType, grouping) + RowTypeUtils.projectRowType(inputRowType, grouping) protected lazy val timestampInternalType: LogicalType = if (inputTimeIsDate) new IntType() else new BigIntType() diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java index b0e754a037b..7449d83a903 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/typeutils/RowTypeUtilsTest.java @@ -18,7 +18,16 @@ package org.apache.flink.table.planner.typeutils; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.types.logical.BigIntType; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; + +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import java.util.Arrays; @@ -27,6 +36,13 @@ import static org.assertj.core.api.Assertions.assertThat; /** Tests for {@link RowTypeUtils}. */ public class RowTypeUtilsTest { + @Rule public ExpectedException expectedException = ExpectedException.none(); + + private final RowType srcType = + RowType.of( + new LogicalType[] {new IntType(), new VarCharType(), new BigIntType()}, + new String[] {"f0", "f1", "f2"}); + @Test public void testGetUniqueName() { assertThat( @@ -39,4 +55,34 @@ public class RowTypeUtilsTest { Arrays.asList("Alice", "Bob"))) .isEqualTo(Arrays.asList("Bob_0", "Bob_1", "Dave", "Alice_0")); } + + @Test + public void testProjectRowType() { + assertThat(RowTypeUtils.projectRowType(srcType, new int[] {0})) + .isEqualTo(RowType.of(new LogicalType[] {new IntType()}, new String[] {"f0"})); + + assertThat(RowTypeUtils.projectRowType(srcType, new int[] {0, 2})) + .isEqualTo( + RowType.of( + new LogicalType[] {new IntType(), new BigIntType()}, + new String[] {"f0", "f2"})); + + assertThat(RowTypeUtils.projectRowType(srcType, new int[] {0, 1, 2})).isEqualTo(srcType); + } + + @Test + public void testInvalidProjectRowType() { + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Invalid projection index: 3"); + RowTypeUtils.projectRowType(srcType, new int[] {0, 1, 2, 3}); + + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Invalid projection index: 3"); + RowTypeUtils.projectRowType(srcType, new int[] {0, 1, 3}); + + expectedException.expect(ValidationException.class); + expectedException.expectMessage("Field names must be unique. Found duplicates"); + RowTypeUtils.projectRowType(srcType, new int[] {0, 0, 0, 0}); + } }