This is an automated email from the ASF dual-hosted git repository.
mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 47ba3d5314 [CALCITE-6742] StandardConvertletTable.convertCall loses casts from ROW comparisons
47ba3d5314 is described below
commit 47ba3d531493e34b063adb51842c9b8984bf796d
Author: Mihai Budiu <[email protected]>
AuthorDate: Fri Dec 20 17:08:58 2024 -0800
[CALCITE-6742] StandardConvertletTable.convertCall loses casts from ROW comparisons
Signed-off-by: Mihai Budiu <[email protected]>
---
.../adapter/enumerable/RexToLixTranslator.java | 7 ++-
.../calcite/sql2rel/StandardConvertletTable.java | 61 +++++++++++++++++++---
.../apache/calcite/test/SqlToRelConverterTest.java | 17 ++++++
.../apache/calcite/test/SqlToRelConverterTest.xml | 22 ++++++++
4 files changed, 100 insertions(+), 7 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
index e38ac3210b..d63c512e88 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
@@ -563,9 +563,14 @@ public class RexToLixTranslator implements RexVisitor<RexToLixTranslator.Result>
case TINYINT:
case SMALLINT: {
if (SqlTypeName.NUMERIC_TYPES.contains(sourceType.getSqlTypeName())) {
+ Type javaClass = typeFactory.getJavaClass(targetType);
+ Primitive primitive = Primitive.of(javaClass);
+ if (primitive == null) {
+ primitive = Primitive.ofBox(javaClass);
+ }
return Expressions.call(
BuiltInMethod.INTEGER_CAST_ROUNDING_MODE.method,
-
Expressions.constant(Primitive.of(typeFactory.getJavaClass(targetType))),
+ Expressions.constant(primitive),
operand,
Expressions.constant(typeFactory.getTypeSystem().roundingMode()));
}
return defaultExpression.get();
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
index 7d9c807cd0..ff0afa6037 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
@@ -90,6 +90,7 @@ import org.checkerframework.checker.nullness.qual.Nullable;
import java.math.BigDecimal;
import java.math.RoundingMode;
+import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
@@ -100,6 +101,7 @@ import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkArgument;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.CAST;
import static
org.apache.calcite.sql.fun.SqlStdOperatorTable.QUANTIFY_OPERATORS;
import static
org.apache.calcite.sql.type.NonNullableAccessors.getComponentTypeOrThrow;
import static org.apache.calcite.util.Util.first;
@@ -143,7 +145,7 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
addAlias(SqlLibraryOperators.BITOR_AGG, SqlStdOperatorTable.BIT_OR);
// Register convertlets for specific objects.
- registerOp(SqlStdOperatorTable.CAST, this::convertCast);
+ registerOp(CAST, this::convertCast);
registerOp(SqlLibraryOperators.SAFE_CAST, this::convertCast);
registerOp(SqlLibraryOperators.TRY_CAST, this::convertCast);
registerOp(SqlLibraryOperators.INFIX_CAST, this::convertCast);
@@ -1151,14 +1153,61 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
// Expand 'ROW (x0, x1, ...) = ROW (y0, y1, ...)'
// to 'x0 = y0 AND x1 = y1 AND ...'
+ // If there are casts to ROW, apply them fieldwise:
+ // 'CAST(ROW(x0, x1) AS (t0, t1)) = ROW(y0, y1)' becomes
+ // 'CAST(x0 as t0) = y0 AND CAST(x1 as t1) = y1'
if (op.kind == SqlKind.EQUALS) {
- final RexNode expr0 = RexUtil.removeCast(exprs.get(0));
- final RexNode expr1 = RexUtil.removeCast(exprs.get(1));
+ // For every cast one list with the types of all fields
+ ArrayDeque<List<RelDataTypeField>> leftTypes = new ArrayDeque<>();
+ ArrayDeque<List<RelDataTypeField>> rightTypes = new ArrayDeque<>();
+ RexNode expr0 = exprs.get(0);
+ RexNode expr1 = exprs.get(1);
+ while (expr0.getKind() == SqlKind.CAST) {
+ RexCall cast = (RexCall) expr0;
+ if (cast.type.getSqlTypeName() == SqlTypeName.ROW) {
+ expr0 = ((RexCall) expr0).operands.get(0);
+ leftTypes.add(cast.type.getFieldList());
+ } else {
+ break;
+ }
+ }
+ while (expr1.getKind() == SqlKind.CAST) {
+ RexCall cast = (RexCall) expr1;
+ if (cast.type.getSqlTypeName() == SqlTypeName.ROW) {
+ expr1 = ((RexCall) expr1).operands.get(0);
+ rightTypes.add(cast.type.getFieldList());
+ } else {
+ break;
+ }
+ }
if (expr0.getKind() == SqlKind.ROW && expr1.getKind() == SqlKind.ROW) {
final RexCall call0 = (RexCall) expr0;
final RexCall call1 = (RexCall) expr1;
+
+ List<RexNode> expr0Operands = call0.getOperands();
+ // Insert the casts in reverse order
+ while (!leftTypes.isEmpty()) {
+ List<RexNode> converted = new ArrayList<>();
+ List<RelDataTypeField> types = leftTypes.removeLast();
+ Pair.forEach(types, expr0Operands, (x, y) ->
+ converted.add(
+ rexBuilder.makeAbstractCast(
+ call.getParserPosition(), x.getType(), y, false)));
+ expr0Operands = converted;
+ }
+ List<RexNode> expr1Operands = call1.getOperands();
+ while (!rightTypes.isEmpty()) {
+ List<RexNode> converted = new ArrayList<>();
+ List<RelDataTypeField> types = rightTypes.removeLast();
+ Pair.forEach(types, expr1Operands, (x, y) ->
+ converted.add(
+ rexBuilder.makeAbstractCast(
+ call.getParserPosition(), x.getType(), y, false)));
+ expr1Operands = converted;
+ }
+
final List<RexNode> eqList = new ArrayList<>();
- Pair.forEach(call0.getOperands(), call1.getOperands(), (x, y) ->
+ Pair.forEach(expr0Operands, expr1Operands, (x, y) ->
eqList.add(rexBuilder.makeCall(call.getParserPosition(), op, x,
y)));
return RexUtil.composeConjunction(rexBuilder, eqList);
}
@@ -1680,7 +1729,7 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
if (argRex == null || argRex.getType().equals(varType)) {
return argInput;
}
- return SqlStdOperatorTable.CAST.createCall(pos, argInput,
+ return CAST.createCall(pos, argInput,
SqlTypeUtil.convertTypeToSpec(varType));
}
}
@@ -1852,7 +1901,7 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
if (argRex == null || argRex.getType().equals(varType)) {
return argInput;
}
- return SqlStdOperatorTable.CAST.createCall(pos, argInput,
+ return CAST.createCall(pos, argInput,
SqlTypeUtil.convertTypeToSpec(varType));
}
}
diff --git a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
index 90ed3b69eb..d97a73f163 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
@@ -852,6 +852,23 @@ class SqlToRelConverterTest extends SqlToRelTestBase {
sql(sql).ok();
}
+ /** Test case for
+ * <a
href="https://issues.apache.org/jira/browse/CALCITE-6742">[CALCITE-6742]
+ * StandardConvertletTable.convertCall loses casts from ROW comparisons</a>.
*/
+ @Test void testStructCast() {
+ final String sql = "select ROW(1, 'x') = ROW('y', 1)";
+ sql(sql).ok();
+ }
+
+ /** Test case for
+ * <a
href="https://issues.apache.org/jira/browse/CALCITE-6742">[CALCITE-6742]
+ * StandardConvertletTable.convertCall loses casts from ROW comparisons</a>.
*/
+ @Test void testStructCast1() {
+ final String sql = "select CAST(CAST(ROW('x', 1) AS "
+ + "ROW(l INTEGER, r DOUBLE)) AS ROW(l BIGINT, r INTEGER)) =
ROW(RAND(), RAND())";
+ sql(sql).ok();
+ }
+
/** As {@link #testSelectOverDistinct()} but for streaming queries. */
@Test void testSelectStreamPartitionDistinct() {
final String sql = "select stream\n"
diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
index 98ab6f3fe5..49a6200e7d 100644
--- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
@@ -7691,6 +7691,28 @@ from orders]]>
LogicalDelta
LogicalProject(ROWTIME=[$0], PRODUCTID=[$1], ORDERID=[$2], C=[COUNT() OVER
(PARTITION BY $1 ORDER BY $0 RANGE 1000:INTERVAL SECOND PRECEDING)])
LogicalTableScan(table=[[CATALOG, SALES, ORDERS]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testStructCast">
+ <Resource name="sql">
+ <![CDATA[select ROW(1, 'x') = ROW('y', 1)]]>
+ </Resource>
+ <Resource name="plan">
+ <![CDATA[
+LogicalProject(EXPR$0=[AND(=(1, CAST('y'):INTEGER NOT NULL),
=(CAST('x'):INTEGER NOT NULL, 1))])
+ LogicalValues(tuples=[[{ 0 }]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testStructCast1">
+ <Resource name="sql">
+ <![CDATA[select CAST(CAST(ROW('x', 1) AS ROW(l INTEGER, r DOUBLE)) AS
ROW(l BIGINT, r INTEGER)) = ROW(RAND(), RAND())]]>
+ </Resource>
+ <Resource name="plan">
+ <![CDATA[
+LogicalProject(EXPR$0=[AND(=(CAST(CAST('x'):INTEGER NOT NULL):DOUBLE NOT NULL,
RAND()), =(1.0E0, RAND()))])
+ LogicalValues(tuples=[[{ 0 }]])
]]>
</Resource>
</TestCase>