This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit 3775307e0b490f1bdad1555809d6bace073c2547
Author: Ran Tao <[email protected]>
AuthorDate: Sun Aug 27 14:33:36 2023 +0800

    [CALCITE-5960] CAST throws NullPointerException if SqlTypeFamily of targetType is null
    
    Add test case (Julian Hyde).
---
 .../adapter/enumerable/RexToLixTranslator.java     | 59 +++++++++-------------
 core/src/test/resources/sql/misc.iq                | 12 +++++
 .../org/apache/calcite/test/SqlOperatorTest.java   | 23 ++++++++-
 3 files changed, 56 insertions(+), 38 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
index 24c1d02748..bea44379b6 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
@@ -290,7 +290,7 @@ public class RexToLixTranslator implements RexVisitor<RexToLixTranslator.Result>
     Expression convert = getConvertExpression(sourceType, targetType, operand);
     Expression convert2 = checkExpressionPadTruncate(convert, sourceType, targetType);
     Expression convert3 = expressionHandlingSafe(convert2, safe);
-    return scaleIntervalToNumber(sourceType, targetType, convert3);
+    return scaleValue(sourceType, targetType, convert3);
   }
 
   private Expression getConvertExpression(
@@ -969,38 +969,23 @@ public class RexToLixTranslator implements RexVisitor<RexToLixTranslator.Result>
     return root;
   }
 
-  private static Expression scaleIntervalToNumber(
+  /** If an expression is a {@code NUMERIC} derived from an {@code INTERVAL},
+   * scales it appropriately; returns the operand unchanged if the conversion
+   * is not from {@code INTERVAL} to {@code NUMERIC}. */
+  private static Expression scaleValue(
       RelDataType sourceType,
       RelDataType targetType,
       Expression operand) {
-    switch (requireNonNull(targetType.getSqlTypeName().getFamily(),
-        () -> "SqlTypeFamily for " + targetType)) {
-    case NUMERIC:
-      switch (sourceType.getSqlTypeName()) {
-      case INTERVAL_YEAR:
-      case INTERVAL_YEAR_MONTH:
-      case INTERVAL_MONTH:
-      case INTERVAL_DAY:
-      case INTERVAL_DAY_HOUR:
-      case INTERVAL_DAY_MINUTE:
-      case INTERVAL_DAY_SECOND:
-      case INTERVAL_HOUR:
-      case INTERVAL_HOUR_MINUTE:
-      case INTERVAL_HOUR_SECOND:
-      case INTERVAL_MINUTE:
-      case INTERVAL_MINUTE_SECOND:
-      case INTERVAL_SECOND:
-        // Scale to the given field.
-        final BigDecimal multiplier = BigDecimal.ONE;
-        final BigDecimal divider =
-            sourceType.getSqlTypeName().getEndUnit().multiplier;
-        return RexImpTable.multiplyDivide(operand, multiplier, divider);
-      default:
-        break;
-      }
-      break;
-    default:
-      break;
+    final SqlTypeFamily targetFamily = targetType.getSqlTypeName().getFamily();
+    final SqlTypeFamily sourceFamily = sourceType.getSqlTypeName().getFamily();
+    if (targetFamily == SqlTypeFamily.NUMERIC
+        && (sourceFamily == SqlTypeFamily.INTERVAL_YEAR_MONTH
+            || sourceFamily == SqlTypeFamily.INTERVAL_DAY_TIME)) {
+      // Scale to the given field.
+      final BigDecimal multiplier = BigDecimal.ONE;
+      final BigDecimal divider =
+          sourceType.getSqlTypeName().getEndUnit().multiplier;
+      return RexImpTable.multiplyDivide(operand, multiplier, divider);
     }
     return operand;
   }
@@ -1009,13 +994,15 @@ public class RexToLixTranslator implements RexVisitor<RexToLixTranslator.Result>
    * Visit {@code RexInputRef}. If it has never been visited
    * under current storage type before, {@code RexToLixTranslator}
    * generally produces three lines of code.
-   * For example, when visiting a column (named commission) in
+   *
+   * <p>For example, when visiting a column (named commission) in
    * table Employee, the generated code snippet is:
-   * {@code
-   *   final Employee current =(Employee) inputEnumerator.current();
-       final Integer input_value = current.commission;
-       final boolean input_isNull = input_value == null;
-   * }
+   *
+   * <blockquote><pre>{@code
+   * final Employee current = (Employee) inputEnumerator.current();
+   * final Integer input_value = current.commission;
+   * final boolean input_isNull = input_value == null;
+   * }</pre></blockquote>
    */
   @Override public Result visitInputRef(RexInputRef inputRef) {
     final Pair<RexNode, @Nullable Type> key = Pair.of(inputRef, currentStorageType);
diff --git a/core/src/test/resources/sql/misc.iq b/core/src/test/resources/sql/misc.iq
index 321843c183..65fe7f085d 100644
--- a/core/src/test/resources/sql/misc.iq
+++ b/core/src/test/resources/sql/misc.iq
@@ -1420,6 +1420,18 @@ from "scott".emp;
 
 !ok
 
+# [CALCITE-5960] CAST failed if SqlTypeFamily of targetType is NULL
+# Cast row
+SELECT cast(row(1, 2) as row(a integer, b tinyint)) as r;
++--------+
+| R      |
++--------+
+| {1, 2} |
++--------+
+(1 row)
+
+!ok
+
 # [CALCITE-877] Allow ROW as argument to COLLECT
 select deptno, collect(r) as empnos
 from (select deptno, (empno, deptno) as r
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index f322eb6b9c..0e23f66669 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -1389,10 +1389,29 @@ public class SqlOperatorTest {
         INVALID_CHAR_MESSAGE, true);
   }
 
+  @Test void testCastRowType() {
+    final SqlOperatorFixture f = fixture();
+    f.checkScalar("cast((1, 2) as row(f0 integer, f1 bigint))",
+        "{1, 2}",
+        "RecordType(INTEGER NOT NULL F0, BIGINT NOT NULL F1) NOT NULL");
+    f.checkScalar("cast((1, 2) as row(f0 integer, f1 decimal(2)))",
+        "{1, 2}",
+        "RecordType(INTEGER NOT NULL F0, DECIMAL(2, 0) NOT NULL F1) NOT NULL");
+    f.checkScalar("cast((1, '2') as row(f0 integer, f1 varchar))",
+        "{1, 2}",
+        "RecordType(INTEGER NOT NULL F0, VARCHAR NOT NULL F1) NOT NULL");
+    f.checkScalar("cast(('A', 'B') as row(f0 varchar, f1 varchar))",
+        "{A, B}",
+        "RecordType(VARCHAR NOT NULL F0, VARCHAR NOT NULL F1) NOT NULL");
+    f.checkNull("cast(null as row(f0 integer, f1 bigint))");
+    f.checkNull("cast(null as row(f0 integer, f1 decimal(2,0)))");
+    f.checkNull("cast(null as row(f0 integer, f1 varchar))");
+    f.checkNull("cast(null as row(f0 varchar, f1 varchar))");
+  }
+
   /** Test case for
    * <a href="https://issues.apache.org/jira/browse/CALCITE-4861">[CALCITE-4861]
-   * Optimisation of chained cast calls can lead to unexpected behaviour.</a>.
-   */
+   * Optimization of chained CAST calls leads to unexpected behavior</a>. */
   @Test void testChainedCast() {
     final SqlOperatorFixture f = fixture();
     f.checkFails("CAST(CAST(CAST(123456 AS TINYINT) AS INT) AS BIGINT)",

Reply via email to