This is an automated email from the ASF dual-hosted git repository.

yuxia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 6663c8b29f6 [FLINK-30966][table-planner] Optimize type inference for 
Character type and fix result type casting in IfCallGen (#21927)
6663c8b29f6 is described below

commit 6663c8b29f672b961d34baed314940621f3754ca
Author: Shuiqiang Chen <acqua....@gmail.com>
AuthorDate: Mon Jun 5 17:46:50 2023 +0800

    [FLINK-30966][table-planner] Optimize type inference for Character type and 
fix result type casting in IfCallGen (#21927)
---
 .../functions/sql/FlinkSqlOperatorTable.java       |  2 +-
 .../table/planner/plan/type/FlinkReturnTypes.java  | 39 +++++++++++--
 .../type/NumericOrDefaultReturnTypeInference.java  | 67 ----------------------
 .../table/planner/codegen/calls/IfCallGen.scala    |  4 +-
 .../planner/runtime/stream/sql/CalcITCase.scala    | 25 ++++++++
 5 files changed, 62 insertions(+), 75 deletions(-)

diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
index a09e809cd9c..cb2ff52c6e8 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
@@ -661,7 +661,7 @@ public class FlinkSqlOperatorTable extends 
ReflectiveSqlOperatorTable {
             new SqlFunction(
                     "IF",
                     SqlKind.OTHER_FUNCTION,
-                    FlinkReturnTypes.NUMERIC_FROM_ARG1_DEFAULT1_NULLABLE,
+                    FlinkReturnTypes.IF_NULLABLE,
                     null,
                     OperandTypes.or(
                             OperandTypes.and(
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/type/FlinkReturnTypes.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/type/FlinkReturnTypes.java
index 1ff51b4a946..6253580c662 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/type/FlinkReturnTypes.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/type/FlinkReturnTypes.java
@@ -31,8 +31,12 @@ import org.apache.calcite.sql.type.ReturnTypes;
 import org.apache.calcite.sql.type.SqlReturnTypeInference;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.type.SqlTypeTransforms;
+import org.apache.calcite.sql.type.SqlTypeUtil;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
 
 /** Type inference in Flink. */
 public class FlinkReturnTypes {
@@ -115,11 +119,36 @@ public class FlinkReturnTypes {
     public static final SqlReturnTypeInference ROUND_FUNCTION_NULLABLE =
             ReturnTypes.cascade(ROUND_FUNCTION, SqlTypeTransforms.TO_NULLABLE);
 
-    public static final SqlReturnTypeInference NUMERIC_FROM_ARG1_DEFAULT1 =
-            new NumericOrDefaultReturnTypeInference(1, 1);
-
-    public static final SqlReturnTypeInference 
NUMERIC_FROM_ARG1_DEFAULT1_NULLABLE =
-            ReturnTypes.cascade(NUMERIC_FROM_ARG1_DEFAULT1, 
SqlTypeTransforms.TO_NULLABLE);
+    /**
+     * Return type inference for IF: operand 0 is the condition and the rest are result arguments.
+     * If every result argument is numeric, character or binary, the return type is their least
+     * restrictive common type from {@code RelDataTypeFactory#leastRestrictive} (e.g. the numeric
+     * type with the largest range); otherwise it is the type of operand 1. The result is nullable.
+     */
+    public static final SqlReturnTypeInference IF_NULLABLE =
+            ReturnTypes.cascade(
+                    new SqlReturnTypeInference() {
+                        @Override
+                        public @Nullable RelDataType 
inferReturnType(SqlOperatorBinding opBinding) {
+                            int nOperands = opBinding.getOperandCount();
+                            List<RelDataType> types = new ArrayList<>();
+                            for (int i = 1; i < nOperands; i++) {
+                                RelDataType type = opBinding.getOperandType(i);
+                                // leastRestrictive() below picks the most general of these
+                                // types, or returns null if they cannot be unified.
+                                if (SqlTypeUtil.isNumeric(type)
+                                        || SqlTypeUtil.isCharacter(type)
+                                        || SqlTypeUtil.isBinary(type)) {
+                                    types.add(type);
+                                } else {
+                                    return opBinding.getOperandType(1);
+                                }
+                            }
+                            return 
opBinding.getTypeFactory().leastRestrictive(types);
+                        }
+                    },
+                    SqlTypeTransforms.TO_NULLABLE);
 
     public static final SqlReturnTypeInference STR_MAP_NULLABLE =
             ReturnTypes.explicit(
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/type/NumericOrDefaultReturnTypeInference.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/type/NumericOrDefaultReturnTypeInference.java
deleted file mode 100644
index 65cf3a8a371..00000000000
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/type/NumericOrDefaultReturnTypeInference.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.table.planner.plan.type;
-
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.sql.SqlOperatorBinding;
-import org.apache.calcite.sql.type.SqlReturnTypeInference;
-import org.apache.calcite.sql.type.SqlTypeUtil;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * Determine the return type of functions with numeric arguments. The return 
type is the type of the
- * argument with the largest range. We start to consider the arguments from 
the `startTypeIdx`-th
- * one. If one of the arguments is not of numeric type, we return the type of 
the
- * `defaultTypeIdx`-th argument instead.
- */
-public class NumericOrDefaultReturnTypeInference implements 
SqlReturnTypeInference {
-    // Default argument whose type is returned
-    // when one of the arguments from the `startTypeIdx`-th isn't of numeric 
type.
-    private int defaultTypeIdx;
-    // We check from the `startTypeIdx`-th argument that
-    // if all the following arguments are of numeric type.
-    // Previous arguments are ignored.
-    private int startTypeIdx;
-
-    public NumericOrDefaultReturnTypeInference(int defaultTypeIdx) {
-        this(defaultTypeIdx, 0);
-    }
-
-    public NumericOrDefaultReturnTypeInference(int defaultTypeIdx, int 
startTypeIdx) {
-        this.defaultTypeIdx = defaultTypeIdx;
-        this.startTypeIdx = startTypeIdx;
-    }
-
-    @Override
-    public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
-        int nOperands = opBinding.getOperandCount();
-        List<RelDataType> types = new ArrayList<>();
-        for (int i = startTypeIdx; i < nOperands; i++) {
-            RelDataType type = opBinding.getOperandType(i);
-            if (SqlTypeUtil.isNumeric(type)) {
-                types.add(type);
-            } else {
-                return opBinding.getOperandType(defaultTypeIdx);
-            }
-        }
-        return opBinding.getTypeFactory().leastRestrictive(types);
-    }
-}
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala
index 9c5e90dbee1..325b06197c6 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/IfCallGen.scala
@@ -52,19 +52,19 @@ class IfCallGen() extends CallGenerator {
     val resultCode =
       s"""
          |// --- Start code generated by ${className[IfCallGen]}
-         |${castedResultTerm1.getCode}
-         |${castedResultTerm2.getCode}
          |${operands.head.code}
          |$resultTerm = $resultDefault;
          |if (${operands.head.resultTerm}) {
          |  ${operands(1).code}
          |  if (!${operands(1).nullTerm}) {
+         |    ${castedResultTerm1.getCode}
          |    $resultTerm = ${castedResultTerm1.getReturnTerm};
          |  }
          |  $nullTerm = ${operands(1).nullTerm};
          |} else {
          |  ${operands(2).code}
          |  if (!${operands(2).nullTerm}) {
+         |    ${castedResultTerm2.getCode}
          |    $resultTerm = ${castedResultTerm2.getReturnTerm};
          |  }
          |  $nullTerm = ${operands(2).nullTerm};
diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CalcITCase.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CalcITCase.scala
index bd4cdf3e9c3..5ebd4acb2ea 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CalcITCase.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/CalcITCase.scala
@@ -366,6 +366,31 @@ class CalcITCase extends StreamingTestBase {
     sink.getAppendResults.foreach(result => assertEquals(expected, result))
   }
 
+  @Test
+  def testIfFunction(): Unit = {
+    val testDataId = TestValuesTableFactory.registerData(TestData.data1)
+    val ddl =
+      s"""
+         |CREATE TABLE t (
+         |  a int,
+         |  b varchar,
+         |  c int
+         |) WITH (
+         |  'connector' = 'values',
+         |  'data-id' = '$testDataId',
+         |  'bounded' = 'true'
+         |)
+         |""".stripMargin
+    tEnv.executeSql(ddl)
+    val expected = List("false,1", "false,2", "false,3", "true,4", "true,5", 
"true,6")
+    val actual = tEnv
+      .executeSql("SELECT IF(a > 3, 'true', 'false'), a from t")
+      .collect()
+      .map(r => r.toString)
+      .toList
+    assertEquals(expected.sorted, actual.sorted)
+  }
+
   @Test
   def testSourceWithCustomInternalData(): Unit = {
 

Reply via email to