This is an automated email from the ASF dual-hosted git repository. libenchao pushed a commit to branch release-1.18 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.18 by this push: new d6081815000 [FLINK-33541][table-planner] RAND and RAND_INTEGER should return type nullable if the arguments are nullable d6081815000 is described below commit d60818150005661006a71e4155fc605d7543362b Author: xuyang <xyzhong...@163.com> AuthorDate: Thu Nov 23 14:05:07 2023 +0800 [FLINK-33541][table-planner] RAND and RAND_INTEGER should return type nullable if the arguments are nullable Close apache/flink#23779 --- .../functions/BuiltInFunctionDefinitions.java | 4 +- .../functions/sql/FlinkSqlOperatorTable.java | 8 +- .../table/planner/codegen/calls/RandCallGen.scala | 5 +- .../planner/functions/BuiltInFunctionTestBase.java | 89 ++++++++++++++------ .../planner/functions/RandFunctionITCase.java | 94 ++++++++++++++++++++++ .../planner/expressions/ScalarFunctionsTest.scala | 45 +++++++++++ .../expressions/utils/ExpressionTestBase.scala | 14 +++- 7 files changed, 226 insertions(+), 33 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index e653d1d6463..82197822dc5 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -1576,7 +1576,7 @@ public final class BuiltInFunctionDefinitions { .kind(SCALAR) .notDeterministic() .inputTypeStrategy(or(NO_ARGS, sequence(logical(LogicalTypeRoot.INTEGER)))) - .outputTypeStrategy(explicit(DataTypes.DOUBLE().notNull())) + .outputTypeStrategy(nullableIfArgs(explicit(DataTypes.DOUBLE()))) .build(); public static final BuiltInFunctionDefinition RAND_INTEGER = @@ -1590,7 +1590,7 @@ public final class BuiltInFunctionDefinitions { sequence( logical(LogicalTypeRoot.INTEGER), logical(LogicalTypeRoot.INTEGER)))) - .outputTypeStrategy(explicit(INT().notNull())) + .outputTypeStrategy(nullableIfArgs(explicit(INT()))) .build(); public static final BuiltInFunctionDefinition BIN = diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java index 9769613cd2d..1a98081dd79 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java @@ -940,7 +940,9 @@ public class FlinkSqlOperatorTable extends ReflectiveSqlOperatorTable { new SqlFunction( "RAND", SqlKind.OTHER_FUNCTION, - ReturnTypes.DOUBLE, + ReturnTypes.cascade( + ReturnTypes.explicit(SqlTypeName.DOUBLE), + SqlTypeTransforms.TO_NULLABLE), null, OperandTypes.or( new SqlSingleOperandTypeChecker[] { @@ -958,7 +960,9 @@ public class FlinkSqlOperatorTable extends ReflectiveSqlOperatorTable { new SqlFunction( "RAND_INTEGER", SqlKind.OTHER_FUNCTION, - ReturnTypes.INTEGER, + ReturnTypes.cascade( + ReturnTypes.explicit(SqlTypeName.INTEGER), + SqlTypeTransforms.TO_NULLABLE), null, OperandTypes.or( new SqlSingleOperandTypeChecker[] { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/RandCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/RandCallGen.scala index 1a7b4950586..e322e1854ed 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/RandCallGen.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/RandCallGen.scala @@ -39,11 +39,12 @@ class RandCallGen(isRandInteger: Boolean, hasSeed: Boolean) extends CallGenerato } if (isRandInteger) { - generateCallIfArgsNotNull(ctx, new IntType(), operands) { + generateCallIfArgsNotNull(ctx, new IntType(returnType.isNullable), operands) { terms => s"$randField.nextInt(${terms.last})" } } else { - generateCallIfArgsNotNull(ctx, new DoubleType(), operands)(_ => s"$randField.nextDouble()") + generateCallIfArgsNotNull(ctx, new DoubleType(returnType.isNullable), operands)( + _ => s"$randField.nextDouble()") } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java index da355199084..d95e4fd4e53 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java @@ -55,8 +55,10 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; +import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; import static org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches; +import static org.apache.flink.table.api.Expressions.row; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.catchThrowable; @@ -133,7 +135,7 @@ abstract class BuiltInFunctionTestBase { private final List<TestItem> testItems; - private Object[] fieldData; + private @Nullable Object[] fieldData; private @Nullable AbstractDataType<?>[] fieldDataTypes; @@ -178,6 +180,11 @@ abstract class BuiltInFunctionTestBase { singletonList(expression), singletonList(result), singletonList(dataType)); } + TestSetSpec testTableApiResult(Expression expression, AbstractDataType<?> dataType) { + return testTableApiResult( + singletonList(expression), emptyList(), singletonList(dataType)); + } + TestSetSpec testTableApiResult( List<Expression> expression, List<Object> result, @@ -209,6 +216,10 @@ abstract class BuiltInFunctionTestBase { return testSqlResult(expression, singletonList(result), singletonList(dataType)); } + TestSetSpec testSqlResult(String expression, AbstractDataType<?> dataType) { + return testSqlResult(expression, emptyList(), singletonList(dataType)); + } + TestSetSpec testSqlResult( String expression, List<Object> result, List<AbstractDataType<?>> dataType) { testItems.add(new SqlResultTestItem(expression, result, dataType)); @@ -303,8 +314,13 @@ abstract class BuiltInFunctionTestBase { functions.forEach( f -> env.createTemporarySystemFunction(f.getSimpleName(), f)); + Preconditions.checkArgument( + !(fieldData == null && fieldDataTypes != null), + "The field data type is set but the field data is not."); final Table inputTable; - if (fieldDataTypes == null) { + if (fieldData == null) { + inputTable = null; + } else if (fieldDataTypes == null) { inputTable = env.fromValues(Row.of(fieldData)); } else { final DataTypes.UnresolvedField[] fields = @@ -329,7 +345,12 @@ abstract class BuiltInFunctionTestBase { } private interface TestItem { - void test(TableEnvironmentInternal env, Table inputTable) throws Exception; + /** + * @param env The table environment for test to execute. + * @param inputTable The input table of this test that contains input data and data type. If + * it is null, the test is not dependent on the input data. + */ + void test(TableEnvironmentInternal env, @Nullable Table inputTable) throws Exception; } private abstract static class ResultTestItem<T> implements TestItem { @@ -343,10 +364,11 @@ abstract class BuiltInFunctionTestBase { this.dataTypes = dataTypes; } - abstract Table query(TableEnvironment env, Table inputTable); + abstract Table query(TableEnvironment env, @Nullable Table inputTable); @Override - public void test(TableEnvironmentInternal env, Table inputTable) throws Exception { + public void test(TableEnvironmentInternal env, @Nullable Table inputTable) + throws Exception { final Table resultTable = this.query(env, inputTable); final List<DataType> expectedDataTypes = @@ -360,20 +382,27 @@ abstract class BuiltInFunctionTestBase { assertThat(iterator).as("No more rows expected.").isExhausted(); for (int i = 0; i < row.getArity(); i++) { - assertThat( - result.getResolvedSchema() - .getColumnDataTypes() - .get(i) - .getLogicalType()) - .as("Logical type for spec [%d] of test [%s] doesn't match.", i, this) - .isEqualTo(expectedDataTypes.get(i).getLogicalType()); - - assertThat(Row.of(row.getField(i))) - .as("Result for spec [%d] of test [%s] doesn't match.", i, this) - .isEqualTo( - // Use Row.equals() to enable equality for complex structure, - // i.e. byte[] - Row.of(this.results.get(i))); + if (!expectedDataTypes.isEmpty()) { + assertThat( + result.getResolvedSchema() + .getColumnDataTypes() + .get(i) + .getLogicalType()) + .as( + "Logical type for spec [%d] of test [%s] doesn't match.", + i, this) + .isEqualTo(expectedDataTypes.get(i).getLogicalType()); + } + + if (!this.results.isEmpty()) { + assertThat(Row.of(row.getField(i))) + .as("Result for spec [%d] of test [%s] doesn't match.", i, this) + .isEqualTo( + // Use Row.equals() to enable equality for complex + // structure, + // i.e. byte[] + Row.of(this.results.get(i))); + } } } } @@ -397,7 +426,7 @@ abstract class BuiltInFunctionTestBase { this.expectedDuringValidation = expectedDuringValidation; } - abstract Table query(TableEnvironment env, Table inputTable); + abstract Table query(TableEnvironment env, @Nullable Table inputTable); Consumer<? super Throwable> errorMatcher() { if (errorClass != null && errorMessage != null) { @@ -410,7 +439,7 @@ abstract class BuiltInFunctionTestBase { } @Override - public void test(TableEnvironmentInternal env, Table inputTable) { + public void test(TableEnvironmentInternal env, @Nullable Table inputTable) { AtomicReference<TableResult> tableResult = new AtomicReference<>(); Throwable t = @@ -442,8 +471,14 @@ abstract class BuiltInFunctionTestBase { } @Override - Table query(TableEnvironment env, Table inputTable) { - return inputTable.select(expression.toArray(new Expression[] {})); + Table query(TableEnvironment env, @Nullable Table inputTable) { + if (inputTable != null) { + return inputTable.select(expression.toArray(new Expression[] {})); + } else { + // use a mock collection table with row "0" to avoid pruning the project + // node with expression by PruneEmptyRules.PROJECT_INSTANCE + return env.fromValues(row(0)).select(expression.toArray(new Expression[] {})); + } } @Override @@ -484,8 +519,12 @@ abstract class BuiltInFunctionTestBase { } @Override - Table query(TableEnvironment env, Table inputTable) { - return env.sqlQuery("SELECT " + expression + " FROM " + inputTable); + Table query(TableEnvironment env, @Nullable Table inputTable) { + if (inputTable != null) { + return env.sqlQuery("SELECT " + expression + " FROM " + inputTable); + } else { + return env.sqlQuery("SELECT " + expression); + } } @Override diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/RandFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/RandFunctionITCase.java new file mode 100644 index 00000000000..2c13cc55379 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/RandFunctionITCase.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions; + +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; + +import java.util.stream.Stream; + +import static org.apache.flink.table.api.Expressions.$; +import static org.apache.flink.table.api.Expressions.rand; +import static org.apache.flink.table.api.Expressions.randInteger; + +/** + * Test for {@link org.apache.flink.table.functions.BuiltInFunctionDefinitions#RAND} and {@link + * org.apache.flink.table.functions.BuiltInFunctionDefinitions#RAND_INTEGER} and their return type. + */ +public class RandFunctionITCase extends BuiltInFunctionTestBase { + + @Override + Stream<TestSetSpec> getTestSetSpecs() { + return Stream.of( + // RAND() + TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND) + .testSqlResult("RAND()", DataTypes.DOUBLE().notNull()) + .testTableApiResult(rand(), DataTypes.DOUBLE().notNull()), + // RAND(INT) + TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND) + .onFieldsWithData(10) + .andDataTypes(DataTypes.INT()) + .testSqlResult("RAND(f0)", 0.7304302967434272, DataTypes.DOUBLE()) + .testTableApiResult(rand($("f0")), 0.7304302967434272, DataTypes.DOUBLE()), + // RAND(INT NOT NULL) + TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND) + .onFieldsWithData(10) + .andDataTypes(DataTypes.INT().notNull()) + .testSqlResult("RAND(f0)", 0.7304302967434272, DataTypes.DOUBLE().notNull()) + .testTableApiResult( + rand($("f0")), 0.7304302967434272, DataTypes.DOUBLE().notNull()), + // RAND_INTEGER(INT) + TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND) + .onFieldsWithData(5) + .andDataTypes(DataTypes.INT()) + .testSqlResult("RAND_INTEGER(f0)", DataTypes.INT()) + .testTableApiResult(randInteger($("f0")), DataTypes.INT()), + // RAND_INTEGER(INT NOT NULL) + TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND) + .onFieldsWithData(5) + .andDataTypes(DataTypes.INT().notNull()) + .testSqlResult("RAND_INTEGER(f0)", DataTypes.INT().notNull()) + .testTableApiResult(randInteger($("f0")), DataTypes.INT().notNull()), + // RAND_INTEGER(INT, INT) + TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND) + .onFieldsWithData(5, 10) + .andDataTypes(DataTypes.INT(), DataTypes.INT()) + .testSqlResult("RAND_INTEGER(f0, f1)", 7, DataTypes.INT()) + .testTableApiResult(randInteger($("f0"), $("f1")), 7, DataTypes.INT()), + // RAND_INTEGER(INT, INT NOT NULL) + TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND) + .onFieldsWithData(5, 10) + .andDataTypes(DataTypes.INT(), DataTypes.INT().notNull()) + .testSqlResult("RAND_INTEGER(f0, f1)", 7, DataTypes.INT()) + .testTableApiResult(randInteger($("f0"), $("f1")), 7, DataTypes.INT()), + // RAND_INTEGER(INT NOT NULL, INT) + TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND) + .onFieldsWithData(5, 10) + .andDataTypes(DataTypes.INT().notNull(), DataTypes.INT()) + .testSqlResult("RAND_INTEGER(f0, f1)", 7, DataTypes.INT()) + .testTableApiResult(randInteger($("f0"), $("f1")), 7, DataTypes.INT()), + // RAND_INTEGER(INT NOT NULL, INT NOT NULL) + TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND) + .onFieldsWithData(5, 10) + .andDataTypes(DataTypes.INT().notNull(), DataTypes.INT().notNull()) + .testSqlResult("RAND_INTEGER(f0, f1)", 7, DataTypes.INT().notNull()) + .testTableApiResult( + randInteger($("f0"), $("f1")), 7, DataTypes.INT().notNull())); + } +} diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala index 96a51a2c7e8..2ed21fd376e 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala @@ -2614,9 +2614,13 @@ class ScalarFunctionsTest extends ScalarTypesTestBase { def testIf(): Unit = { // test IF(BOOL, INT, BIGINT), will do implicit type coercion. testSqlApi("IF(f7 > 5, f14, f4)", "44") + // test input with null testSqlApi("IF(f7 < 5, cast(null as int), f4)", "NULL") + // test IF(BOOLEAN, INT, INT NOT NULL) + testSqlApi("IF(f7 < 5, cast(null as int), 0)", "NULL") + // f0 is a STRING, cast(f0 as double) should never be ran testSqlApi("IF(1 = 1, f6, cast(f0 as double))", "4.6") @@ -2661,6 +2665,47 @@ class ScalarFunctionsTest extends ScalarTypesTestBase { // "1996-11-10 06:55:44.333") } + @Test + def testRandAndIf(): Unit = { + // test RAND + testSqlApi("IF(1 = 1, RAND(), cast(1.0 as double))") + + // test RAND(INT) and IF(BOOLEAN, DOUBLE, DOUBLE NOT NULL) + testSqlApi("IF(1 = 1, RAND(f7), cast(1.0 as double))", "0.731057369148862") + + // test RAND(NULL) and IF(BOOLEAN, NULL, DOUBLE NOT NULL) + testSqlApi("IF(1 = 1, RAND(cast(null as int)), cast(1.0 as double))", "NULL") + + // test RAND(INT NOT NULL) and IF(BOOLEAN, DOUBLE NOT NULL, DOUBLE NOT NULL) + testSqlApi("IF(1 = 1, RAND(1), cast(1.0 as double))", "0.7308781907032909") + + // test RAND_INTEGER + + // test RAND_INTEGER(INT) and IF(BOOLEAN, INT, INT NOT NULL) + testSqlApi("IF(1 = 1, RAND_INTEGER(f7), 1)") + + // test RAND_INTEGER(INT NOT NULL) and IF(BOOLEAN, INT NOT NULL, INT NOT NULL) + testSqlApi("IF(1 = 1, RAND_INTEGER(10), 1)") + + // test RAND_INTEGER(NULL) and IF(BOOLEAN, NULL, INT NOT NULL) + testSqlApi("IF(1 = 1, RAND_INTEGER(cast(null as int)), 1)") + + // test RAND_INTEGER(INT, INT NOT NULL) and IF(BOOLEAN, INT, INT NOT NULL) + testSqlApi("IF(1 = 1, RAND_INTEGER(f7, 10), 1)", "4") + + // test RAND_INTEGER(INT NOT NULL, INT) and IF(BOOLEAN, INT, INT NOT NULL) + testSqlApi("IF(1 = 1, RAND_INTEGER(3, f7), 1)", "2") + + // test RAND_INTEGER(NULL, INT NOT NULL) and IF(BOOLEAN, NULL, INT NOT NULL) + testSqlApi("IF(1 = 1, RAND_INTEGER(cast(null as int), 10), 1)", "NULL") + + // test RAND_INTEGER(INT NOT NULL, INT) and IF(BOOLEAN, NULL, INT NOT NULL) + testSqlApi("IF(1 = 1, RAND_INTEGER(3, cast(null as int)), 1)", "NULL") + + // test RAND_INTEGER(INT NOT NULL, INT NOT NULL) and IF(BOOLEAN, INT NOT NULL, INT NOT NULL) + testSqlApi("IF(1 = 1, RAND_INTEGER(3, 99), 1)", "50") + } + @Test def testIfDecimal(): Unit = { // test DECIMAL, DECIMAL diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala index 0b8b5ffb359..ef867d45454 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala @@ -57,6 +57,8 @@ import org.junit.{After, Before, Rule} import org.junit.Assert.{assertEquals, assertTrue, fail} import org.junit.rules.ExpectedException +import javax.annotation.Nullable + import java.util.Collections import scala.collection.JavaConverters._ @@ -198,6 +200,10 @@ abstract class ExpressionTestBase(isStreaming: Boolean = true) { addSqlTestExpr(sqlExpr, expected, validExprs) } + def testSqlApi(sqlExpr: String): Unit = { + addSqlTestExpr(sqlExpr, null, validExprs) + } + def testExpectedAllApisException( expr: Expression, sqlExpr: String, @@ -293,7 +299,7 @@ abstract class ExpressionTestBase(isStreaming: Boolean = true) { private def addSqlTestExpr( sqlExpr: String, - expected: String, + @Nullable expected: String, exprsContainer: mutable.ArrayBuffer[_], exceptionClass: Class[_ <: Throwable] = null): Unit = { // create RelNode from SQL expression @@ -318,7 +324,7 @@ abstract class ExpressionTestBase(isStreaming: Boolean = true) { private def addTestExpr( relNode: RelNode, - expected: String, + @Nullable expected: String, summaryString: String, exceptionClass: Class[_ <: Throwable], exprs: mutable.ArrayBuffer[_]): Unit = { @@ -355,6 +361,10 @@ abstract class ExpressionTestBase(isStreaming: Boolean = true) { .zip(result) .foreach { case ((originalExpr, optimizedExpr, expected), actual) => + if (expected == null) { + // no need to check the result + return + } val original = if (originalExpr == null) "" else s"for: [$originalExpr]" assertEquals( s"Wrong result $original optimized to: [$optimizedExpr]",