This is an automated email from the ASF dual-hosted git repository.

libenchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 45f966e8c3c [FLINK-33541][table-planner] RAND and RAND_INTEGER should return type nullable if the arguments are nullable
45f966e8c3c is described below

commit 45f966e8c3c5e903b3843391874f7d2478122d8c
Author: xuyang <xyzhong...@163.com>
AuthorDate: Thu Nov 23 14:05:07 2023 +0800

    [FLINK-33541][table-planner] RAND and RAND_INTEGER should return type nullable if the arguments are nullable
    
    Close apache/flink#23779
---
 .../functions/BuiltInFunctionDefinitions.java      |  4 +-
 .../functions/sql/FlinkSqlOperatorTable.java       |  8 +-
 .../table/planner/codegen/calls/RandCallGen.scala  |  5 +-
 .../planner/functions/BuiltInFunctionTestBase.java | 89 ++++++++++++++------
 .../planner/functions/RandFunctionITCase.java      | 94 ++++++++++++++++++++++
 .../planner/expressions/ScalarFunctionsTest.scala  | 45 +++++++++++
 .../expressions/utils/ExpressionTestBase.scala     | 14 +++-
 7 files changed, 226 insertions(+), 33 deletions(-)

diff --git 
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
 
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
index 89ac66d18f6..b65afdc4284 100644
--- 
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
+++ 
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
@@ -1644,7 +1644,7 @@ public final class BuiltInFunctionDefinitions {
                     .kind(SCALAR)
                     .notDeterministic()
                     .inputTypeStrategy(or(NO_ARGS, 
sequence(logical(LogicalTypeRoot.INTEGER))))
-                    .outputTypeStrategy(explicit(DataTypes.DOUBLE().notNull()))
+                    
.outputTypeStrategy(nullableIfArgs(explicit(DataTypes.DOUBLE())))
                     .build();
 
     public static final BuiltInFunctionDefinition RAND_INTEGER =
@@ -1658,7 +1658,7 @@ public final class BuiltInFunctionDefinitions {
                                     sequence(
                                             logical(LogicalTypeRoot.INTEGER),
                                             logical(LogicalTypeRoot.INTEGER))))
-                    .outputTypeStrategy(explicit(INT().notNull()))
+                    .outputTypeStrategy(nullableIfArgs(explicit(INT())))
                     .build();
 
     public static final BuiltInFunctionDefinition BIN =
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
index 9769613cd2d..1a98081dd79 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
@@ -940,7 +940,9 @@ public class FlinkSqlOperatorTable extends 
ReflectiveSqlOperatorTable {
             new SqlFunction(
                     "RAND",
                     SqlKind.OTHER_FUNCTION,
-                    ReturnTypes.DOUBLE,
+                    ReturnTypes.cascade(
+                            ReturnTypes.explicit(SqlTypeName.DOUBLE),
+                            SqlTypeTransforms.TO_NULLABLE),
                     null,
                     OperandTypes.or(
                             new SqlSingleOperandTypeChecker[] {
@@ -958,7 +960,9 @@ public class FlinkSqlOperatorTable extends 
ReflectiveSqlOperatorTable {
             new SqlFunction(
                     "RAND_INTEGER",
                     SqlKind.OTHER_FUNCTION,
-                    ReturnTypes.INTEGER,
+                    ReturnTypes.cascade(
+                            ReturnTypes.explicit(SqlTypeName.INTEGER),
+                            SqlTypeTransforms.TO_NULLABLE),
                     null,
                     OperandTypes.or(
                             new SqlSingleOperandTypeChecker[] {
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/RandCallGen.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/RandCallGen.scala
index 1a7b4950586..e322e1854ed 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/RandCallGen.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/RandCallGen.scala
@@ -39,11 +39,12 @@ class RandCallGen(isRandInteger: Boolean, hasSeed: Boolean) 
extends CallGenerato
     }
 
     if (isRandInteger) {
-      generateCallIfArgsNotNull(ctx, new IntType(), operands) {
+      generateCallIfArgsNotNull(ctx, new IntType(returnType.isNullable), 
operands) {
         terms => s"$randField.nextInt(${terms.last})"
       }
     } else {
-      generateCallIfArgsNotNull(ctx, new DoubleType(), operands)(_ => 
s"$randField.nextDouble()")
+      generateCallIfArgsNotNull(ctx, new DoubleType(returnType.isNullable), 
operands)(
+        _ => s"$randField.nextDouble()")
     }
   }
 
diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java
 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java
index 566a0106083..96e49e96f47 100644
--- 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java
+++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInFunctionTestBase.java
@@ -57,8 +57,10 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
+import static java.util.Collections.emptyList;
 import static java.util.Collections.singletonList;
 import static org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches;
+import static org.apache.flink.table.api.Expressions.row;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.catchThrowable;
@@ -137,7 +139,7 @@ abstract class BuiltInFunctionTestBase {
 
         private final List<TestItem> testItems;
 
-        private Object[] fieldData;
+        private @Nullable Object[] fieldData;
 
         private @Nullable AbstractDataType<?>[] fieldDataTypes;
 
@@ -182,6 +184,11 @@ abstract class BuiltInFunctionTestBase {
                     singletonList(expression), singletonList(result), 
singletonList(dataType));
         }
 
+        TestSetSpec testTableApiResult(Expression expression, AbstractDataType<?> dataType) {
+            final List<Expression> expressions = singletonList(expression);
+            return testTableApiResult(expressions, emptyList(), singletonList(dataType));
+        }
+
         TestSetSpec testTableApiResult(
                 List<Expression> expression,
                 List<Object> result,
@@ -213,6 +220,10 @@ abstract class BuiltInFunctionTestBase {
             return testSqlResult(expression, singletonList(result), 
singletonList(dataType));
         }
 
+        TestSetSpec testSqlResult(String expression, AbstractDataType<?> dataType) {
+            return this.testSqlResult(expression, emptyList(), singletonList(dataType));
+        }
+
         TestSetSpec testSqlResult(
                 String expression, List<Object> result, 
List<AbstractDataType<?>> dataType) {
             testItems.add(new SqlResultTestItem(expression, result, dataType));
@@ -308,8 +319,13 @@ abstract class BuiltInFunctionTestBase {
                         functions.forEach(
                                 f -> 
env.createTemporarySystemFunction(f.getSimpleName(), f));
 
+                        Preconditions.checkArgument(
+                                !(fieldData == null && fieldDataTypes != null),
+                                "The field data type is set but the field data 
is not.");
                         final Table inputTable;
-                        if (fieldDataTypes == null) {
+                        if (fieldData == null) {
+                            inputTable = null;
+                        } else if (fieldDataTypes == null) {
                             inputTable = env.fromValues(Row.of(fieldData));
                         } else {
                             final DataTypes.UnresolvedField[] fields =
@@ -334,7 +350,12 @@ abstract class BuiltInFunctionTestBase {
     }
 
     private interface TestItem {
-        void test(TableEnvironmentInternal env, Table inputTable) throws 
Exception;
+        /**
+         * @param env the table environment in which the test is executed
+         * @param inputTable the input table holding the test's input data and data types, or
+         *     null if the test does not depend on any input data
+         */
+        void test(TableEnvironmentInternal env, @Nullable Table inputTable) 
throws Exception;
     }
 
     private abstract static class ResultTestItem<T> implements TestItem {
@@ -348,10 +369,11 @@ abstract class BuiltInFunctionTestBase {
             this.dataTypes = dataTypes;
         }
 
-        abstract Table query(TableEnvironment env, Table inputTable);
+        abstract Table query(TableEnvironment env, @Nullable Table inputTable);
 
         @Override
-        public void test(TableEnvironmentInternal env, Table inputTable) 
throws Exception {
+        public void test(TableEnvironmentInternal env, @Nullable Table 
inputTable)
+                throws Exception {
             final Table resultTable = this.query(env, inputTable);
 
             final List<DataType> expectedDataTypes =
@@ -365,20 +387,27 @@ abstract class BuiltInFunctionTestBase {
                 assertThat(iterator).as("No more rows 
expected.").isExhausted();
 
                 for (int i = 0; i < row.getArity(); i++) {
-                    assertThat(
-                                    result.getResolvedSchema()
-                                            .getColumnDataTypes()
-                                            .get(i)
-                                            .getLogicalType())
-                            .as("Logical type for spec [%d] of test [%s] 
doesn't match.", i, this)
-                            
.isEqualTo(expectedDataTypes.get(i).getLogicalType());
-
-                    assertThat(Row.of(row.getField(i)))
-                            .as("Result for spec [%d] of test [%s] doesn't 
match.", i, this)
-                            .isEqualTo(
-                                    // Use Row.equals() to enable equality for 
complex structure,
-                                    // i.e. byte[]
-                                    Row.of(this.results.get(i)));
+                    if (!expectedDataTypes.isEmpty()) {
+                        assertThat(
+                                        result.getResolvedSchema()
+                                                .getColumnDataTypes()
+                                                .get(i)
+                                                .getLogicalType())
+                                .as(
+                                        "Logical type for spec [%d] of test 
[%s] doesn't match.",
+                                        i, this)
+                                
.isEqualTo(expectedDataTypes.get(i).getLogicalType());
+                    }
+
+                    if (!this.results.isEmpty()) {
+                        assertThat(Row.of(row.getField(i)))
+                                .as("Result for spec [%d] of test [%s] doesn't 
match.", i, this)
+                                .isEqualTo(
+                                        // Use Row.equals() to enable equality 
for complex
+                                        // structure,
+                                        // i.e. byte[]
+                                        Row.of(this.results.get(i)));
+                    }
                 }
             }
         }
@@ -402,7 +431,7 @@ abstract class BuiltInFunctionTestBase {
             this.expectedDuringValidation = expectedDuringValidation;
         }
 
-        abstract Table query(TableEnvironment env, Table inputTable);
+        abstract Table query(TableEnvironment env, @Nullable Table inputTable);
 
         Consumer<? super Throwable> errorMatcher() {
             if (errorClass != null && errorMessage != null) {
@@ -415,7 +444,7 @@ abstract class BuiltInFunctionTestBase {
         }
 
         @Override
-        public void test(TableEnvironmentInternal env, Table inputTable) {
+        public void test(TableEnvironmentInternal env, @Nullable Table 
inputTable) {
             AtomicReference<TableResult> tableResult = new AtomicReference<>();
 
             Throwable t =
@@ -447,8 +476,14 @@ abstract class BuiltInFunctionTestBase {
         }
 
         @Override
-        Table query(TableEnvironment env, Table inputTable) {
-            return inputTable.select(expression.toArray(new Expression[] {}));
+        Table query(TableEnvironment env, @Nullable Table inputTable) {
+            if (inputTable == null) {
+                // use a single-row dummy table (row "0") so that the projection holding the
+                // expressions is not pruned by PruneEmptyRules.PROJECT_INSTANCE
+                return env.fromValues(row(0)).select(expression.toArray(new Expression[] {}));
+            } else {
+                return inputTable.select(expression.toArray(new Expression[] {}));
+            }
         }
 
         @Override
@@ -519,8 +554,12 @@ abstract class BuiltInFunctionTestBase {
         }
 
         @Override
-        Table query(TableEnvironment env, Table inputTable) {
-            return env.sqlQuery("SELECT " + expression + " FROM " + inputTable);
+        Table query(TableEnvironment env, @Nullable Table inputTable) {
+            if (inputTable == null) {
+                return env.sqlQuery("SELECT " + expression);
+            } else {
+                return env.sqlQuery("SELECT " + expression + " FROM " + inputTable);
+            }
         }
 
         @Override
diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/RandFunctionITCase.java
 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/RandFunctionITCase.java
new file mode 100644
index 00000000000..2c13cc55379
--- /dev/null
+++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/RandFunctionITCase.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions;
+
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
+
+import java.util.stream.Stream;
+
+import static org.apache.flink.table.api.Expressions.$;
+import static org.apache.flink.table.api.Expressions.rand;
+import static org.apache.flink.table.api.Expressions.randInteger;
+
+/**
+ * Test for {@link org.apache.flink.table.functions.BuiltInFunctionDefinitions#RAND} and {@link
+ * org.apache.flink.table.functions.BuiltInFunctionDefinitions#RAND_INTEGER} and their return type.
+ */
+public class RandFunctionITCase extends BuiltInFunctionTestBase {
+
+    @Override
+    Stream<TestSetSpec> getTestSetSpecs() {
+        return Stream.of(
+                // RAND()
+                TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND)
+                        .testSqlResult("RAND()", DataTypes.DOUBLE().notNull())
+                        .testTableApiResult(rand(), DataTypes.DOUBLE().notNull()),
+                // RAND(INT)
+                TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND)
+                        .onFieldsWithData(10)
+                        .andDataTypes(DataTypes.INT())
+                        .testSqlResult("RAND(f0)", 0.7304302967434272, DataTypes.DOUBLE())
+                        .testTableApiResult(rand($("f0")), 0.7304302967434272, DataTypes.DOUBLE()),
+                // RAND(INT NOT NULL)
+                TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND)
+                        .onFieldsWithData(10)
+                        .andDataTypes(DataTypes.INT().notNull())
+                        .testSqlResult("RAND(f0)", 0.7304302967434272, DataTypes.DOUBLE().notNull())
+                        .testTableApiResult(
+                                rand($("f0")), 0.7304302967434272, DataTypes.DOUBLE().notNull()),
+                // RAND_INTEGER(INT)
+                TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND_INTEGER)
+                        .onFieldsWithData(5)
+                        .andDataTypes(DataTypes.INT())
+                        .testSqlResult("RAND_INTEGER(f0)", DataTypes.INT())
+                        .testTableApiResult(randInteger($("f0")), DataTypes.INT()),
+                // RAND_INTEGER(INT NOT NULL)
+                TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND_INTEGER)
+                        .onFieldsWithData(5)
+                        .andDataTypes(DataTypes.INT().notNull())
+                        .testSqlResult("RAND_INTEGER(f0)", DataTypes.INT().notNull())
+                        .testTableApiResult(randInteger($("f0")), DataTypes.INT().notNull()),
+                // RAND_INTEGER(INT, INT)
+                TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND_INTEGER)
+                        .onFieldsWithData(5, 10)
+                        .andDataTypes(DataTypes.INT(), DataTypes.INT())
+                        .testSqlResult("RAND_INTEGER(f0, f1)", 7, DataTypes.INT())
+                        .testTableApiResult(randInteger($("f0"), $("f1")), 7, DataTypes.INT()),
+                // RAND_INTEGER(INT, INT NOT NULL)
+                TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND_INTEGER)
+                        .onFieldsWithData(5, 10)
+                        .andDataTypes(DataTypes.INT(), DataTypes.INT().notNull())
+                        .testSqlResult("RAND_INTEGER(f0, f1)", 7, DataTypes.INT())
+                        .testTableApiResult(randInteger($("f0"), $("f1")), 7, DataTypes.INT()),
+                // RAND_INTEGER(INT NOT NULL, INT)
+                TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND_INTEGER)
+                        .onFieldsWithData(5, 10)
+                        .andDataTypes(DataTypes.INT().notNull(), DataTypes.INT())
+                        .testSqlResult("RAND_INTEGER(f0, f1)", 7, DataTypes.INT())
+                        .testTableApiResult(randInteger($("f0"), $("f1")), 7, DataTypes.INT()),
+                // RAND_INTEGER(INT NOT NULL, INT NOT NULL)
+                TestSetSpec.forFunction(BuiltInFunctionDefinitions.RAND_INTEGER)
+                        .onFieldsWithData(5, 10)
+                        .andDataTypes(DataTypes.INT().notNull(), DataTypes.INT().notNull())
+                        .testSqlResult("RAND_INTEGER(f0, f1)", 7, DataTypes.INT().notNull())
+                        .testTableApiResult(
+                                randInteger($("f0"), $("f1")), 7, DataTypes.INT().notNull()));
+    }
+}
diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala
index 983b272fb3b..d226ed93be9 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala
@@ -2614,9 +2614,13 @@ class ScalarFunctionsTest extends ScalarTypesTestBase {
   def testIf(): Unit = {
     // test IF(BOOL, INT, BIGINT), will do implicit type coercion.
     testSqlApi("IF(f7 > 5, f14, f4)", "44")
+
     // test input with null
     testSqlApi("IF(f7 < 5, cast(null as int), f4)", "NULL")
 
+    // test IF(BOOLEAN, INT, INT NOT NULL)
+    testSqlApi("IF(f7 < 5, cast(null as int), 0)", "NULL")
+
     // f0 is a STRING, cast(f0 as double) should never be ran
     testSqlApi("IF(1 = 1, f6, cast(f0 as double))", "4.6")
 
@@ -2661,6 +2665,47 @@ class ScalarFunctionsTest extends ScalarTypesTestBase {
     //      "1996-11-10 06:55:44.333")
   }
 
+  @Test
+  def testRandAndIf(): Unit = {
+    // test RAND() and IF(BOOLEAN, DOUBLE NOT NULL, DOUBLE NOT NULL); result is not compared
+    testSqlApi("IF(1 = 1, RAND(), cast(1.0 as double))")
+    // NOTE(review): the null-expected `return` in ExpressionTestBase may skip checks below
+    // test RAND(INT) and IF(BOOLEAN, DOUBLE, DOUBLE NOT NULL)
+    testSqlApi("IF(1 = 1, RAND(f7), cast(1.0 as double))", "0.731057369148862")
+
+    // test RAND(NULL) and IF(BOOLEAN, NULL, DOUBLE NOT NULL)
+    testSqlApi("IF(1 = 1, RAND(cast(null as int)), cast(1.0 as double))", 
"NULL")
+
+    // test RAND(INT NOT NULL) and IF(BOOLEAN, DOUBLE NOT NULL, DOUBLE NOT 
NULL)
+    testSqlApi("IF(1 = 1, RAND(1), cast(1.0 as double))", "0.7308781907032909")
+
+    // test RAND_INTEGER
+
+    // test RAND_INTEGER(INT) and IF(BOOLEAN, INT, INT NOT NULL)
+    testSqlApi("IF(1 = 1, RAND_INTEGER(f7), 1)")
+
+    // test RAND_INTEGER(INT NOT NULL) and IF(BOOLEAN, INT NOT NULL, INT NOT 
NULL)
+    testSqlApi("IF(1 = 1, RAND_INTEGER(10), 1)")
+
+    // test RAND_INTEGER(NULL) and IF(BOOLEAN, NULL, INT NOT NULL)
+    testSqlApi("IF(1 = 1, RAND_INTEGER(cast(null as int)), 1)")
+
+    // test RAND_INTEGER(INT, INT NOT NULL) and IF(BOOLEAN, INT, INT NOT NULL)
+    testSqlApi("IF(1 = 1, RAND_INTEGER(f7, 10), 1)", "4")
+
+    // test RAND_INTEGER(INT NOT NULL, INT) and IF(BOOLEAN, INT, INT NOT NULL)
+    testSqlApi("IF(1 = 1, RAND_INTEGER(3, f7), 1)", "2")
+
+    // test RAND_INTEGER(NULL, INT NOT NULL) and IF(BOOLEAN, NULL, INT NOT 
NULL)
+    testSqlApi("IF(1 = 1, RAND_INTEGER(cast(null as int), 10), 1)", "NULL")
+
+    // test RAND_INTEGER(INT NOT NULL, NULL) and IF(BOOLEAN, NULL, INT NOT NULL)
+    testSqlApi("IF(1 = 1, RAND_INTEGER(3, cast(null as int)), 1)", "NULL")
+
+    // test RAND_INTEGER(INT NOT NULL, INT NOT NULL) and IF(BOOLEAN, INT NOT 
NULL, INT NOT NULL)
+    testSqlApi("IF(1 = 1, RAND_INTEGER(3, 99), 1)", "50")
+  }
+
   @Test
   def testIfDecimal(): Unit = {
     // test DECIMAL, DECIMAL
diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala
index dee5eda6111..0dac08610b4 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala
@@ -57,6 +57,8 @@ import org.assertj.core.api.ThrowableAssert.ThrowingCallable
 import org.junit.jupiter.api.{AfterEach, BeforeEach}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
 
+import javax.annotation.Nullable
+
 import java.util.Collections
 
 import scala.collection.JavaConverters._
@@ -187,6 +189,10 @@ abstract class ExpressionTestBase(isStreaming: Boolean = 
true) {
     addSqlTestExpr(sqlExpr, expected, validExprs)
   }
 
+  def testSqlApi(sqlExpr: String): Unit = {
+    addSqlTestExpr(sqlExpr, expected = null, exprsContainer = validExprs)
+  }
+
   def testExpectedAllApisException(
       expr: Expression,
       sqlExpr: String,
@@ -282,7 +288,7 @@ abstract class ExpressionTestBase(isStreaming: Boolean = 
true) {
 
   private def addSqlTestExpr(
       sqlExpr: String,
-      expected: String,
+      @Nullable expected: String,
       exprsContainer: mutable.ArrayBuffer[_],
       exceptionClass: Class[_ <: Throwable] = null): Unit = {
     // create RelNode from SQL expression
@@ -307,7 +313,7 @@ abstract class ExpressionTestBase(isStreaming: Boolean = 
true) {
 
   private def addTestExpr(
       relNode: RelNode,
-      expected: String,
+      @Nullable expected: String,
       summaryString: String,
       exceptionClass: Class[_ <: Throwable],
       exprs: mutable.ArrayBuffer[_]): Unit = {
@@ -344,6 +350,10 @@ abstract class ExpressionTestBase(isStreaming: Boolean = 
true) {
       .zip(result)
       .foreach {
         case ((originalExpr, optimizedExpr, expected), actual) =>
+          if (expected == null) {
+            // no need to check the result
+            return
+          }
           val original = if (originalExpr == null) "" else s"for: 
[$originalExpr]"
           assertEquals(
             expected,

Reply via email to