This is an automated email from the ASF dual-hosted git repository. mbudiu pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push: new c49792f9c7 [CALCITE-6032] Multilevel correlated query is failing in RelDecorrelator code path c49792f9c7 is described below commit c49792f9c72159571f898c5fca1e26cba9870b07 Author: Hanumath Maduri <hanu....@gmail.com> AuthorDate: Fri Jan 19 11:14:46 2024 -0800 [CALCITE-6032] Multilevel correlated query is failing in RelDecorrelator code path --- .../java/org/apache/calcite/plan/RelOptUtil.java | 20 ++++++- .../main/java/org/apache/calcite/rex/RexUtil.java | 23 ++++++++ .../apache/calcite/sql2rel/RelFieldTrimmer.java | 20 ++++++- .../apache/calcite/sql2rel/SqlToRelConverter.java | 29 +++++++-- .../java/org/apache/calcite/test/JdbcTest.java | 69 ++++++++++++++++++++++ 5 files changed, 154 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java index e8c9a48809..5387acd8cf 100644 --- a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java +++ b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java @@ -271,7 +271,7 @@ public abstract class RelOptUtil { } /** - * Returns a set of variables used by a relational expression or its + * Returns the set of variables used by a relational expression or its * descendants. * * <p>The set may contain "duplicates" (variables with different ids that, @@ -286,6 +286,24 @@ public abstract class RelOptUtil { return visitor.vuv.variables; } + /** + * Returns the set of variables used by the given sub-queries and their descendants. + * + * @param subQueries The sub-queries containing correlation variables + * @return The set of correlation identifiers found within the sub-queries. + * Each {@link CorrelationId} in the set corresponds to a + * {@link org.apache.calcite.rex.RexCorrelVariable#id}.
+ */ + public static Set<CorrelationId> getVariablesUsed(List<RexSubQuery> subQueries) { + // Internally this function calls getVariablesUsed on a RelNode to get all the + // correlated variables in that RelNode + Set<CorrelationId> correlationIds = new HashSet<>(); + for (RexSubQuery subQ : subQueries) { + correlationIds.addAll(getVariablesUsed(subQ.rel)); + } + return correlationIds; + } + /** Finds which columns of a correlation variable are used within a * relational expression. */ public static ImmutableBitSet correlationColumns(CorrelationId id, diff --git a/core/src/main/java/org/apache/calcite/rex/RexUtil.java b/core/src/main/java/org/apache/calcite/rex/RexUtil.java index f9390a5b3c..5d7f473a20 100644 --- a/core/src/main/java/org/apache/calcite/rex/RexUtil.java +++ b/core/src/main/java/org/apache/calcite/rex/RexUtil.java @@ -2854,6 +2854,29 @@ public class RexUtil { } } + /** Visitor that collects all the top level SubQueries {@link RexSubQuery} + * in a projection list of a given {@link Project}.*/ + public static class SubQueryCollector extends RexVisitorImpl<Void> { + private List<RexSubQuery> subQueries; + private SubQueryCollector() { + super(true); + this.subQueries = new ArrayList<>(); + } + + @Override public Void visitSubQuery(RexSubQuery subQuery) { + subQueries.add(subQuery); + return null; + } + + public static List<RexSubQuery> collect(Project project) { + SubQueryCollector subQueryCollector = new SubQueryCollector(); + for (RexNode node : project.getProjects()) { + node.accept(subQueryCollector); + } + return subQueryCollector.subQueries; + } + } + /** Visitor that throws {@link org.apache.calcite.util.Util.FoundOne} if * applied to an expression that contains a {@link RexSubQuery}. 
*/ public static class SubQueryFinder extends RexVisitorImpl<Void> { diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java index bfa69d0a4d..a8d99126ea 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java @@ -55,6 +55,7 @@ import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexPermuteInputsShuttle; import org.apache.calcite.rex.RexProgram; +import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.rex.RexVisitor; import org.apache.calcite.sql.SqlExplainFormat; @@ -489,7 +490,24 @@ public class RelFieldTrimmer implements ReflectiveVisitor { ord.e.accept(inputFinder); } } - ImmutableBitSet inputFieldsUsed = inputFinder.build(); + + // Collect all the SubQueries in the projection list. + List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project); + // Get all the correlationIds present in the SubQueries + Set<CorrelationId> correlationIds = RelOptUtil.getVariablesUsed(subQueries); + ImmutableBitSet requiredColumns = ImmutableBitSet.of(); + if (correlationIds.size() > 0) { + assert correlationIds.size() == 1; + // Correlation columns are also needed by SubQueries, so add them to inputFieldsUsed. + requiredColumns = RelOptUtil.correlationColumns(correlationIds.iterator().next(), project); + } + + ImmutableBitSet finderFields = inputFinder.build(); + + ImmutableBitSet inputFieldsUsed = ImmutableBitSet.builder() + .addAll(requiredColumns) + .addAll(finderFields) + .build(); // Create input with trimmed columns. 
TrimResult trimResult = diff --git a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java index 0998c05403..2b79213838 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java @@ -5491,18 +5491,25 @@ public class SqlToRelConverter { builder.add(convertExpression(node)); } final ImmutableList<RexNode> list = builder.build(); + RelNode rel = root.rel; + // Fix the correlation namespaces and de-duplicate the correlation variables. + CorrelationUse correlationUse = getCorrelationUse(this, root.rel); + if (correlationUse != null) { + rel = correlationUse.r; + } + switch (kind) { case IN: - return RexSubQuery.in(root.rel, list); + return RexSubQuery.in(rel, list); case NOT_IN: return rexBuilder.makeCall(SqlStdOperatorTable.NOT, - RexSubQuery.in(root.rel, list)); + RexSubQuery.in(rel, list)); case SOME: - return RexSubQuery.some(root.rel, list, + return RexSubQuery.some(rel, list, (SqlQuantifyOperator) call.getOperator()); case ALL: return rexBuilder.makeCall(SqlStdOperatorTable.NOT, - RexSubQuery.some(root.rel, list, + RexSubQuery.some(rel, list, negate((SqlQuantifyOperator) call.getOperator()))); default: throw new AssertionError(kind); @@ -5515,6 +5522,12 @@ public class SqlToRelConverter { query = Iterables.getOnlyElement(call.getOperandList()); root = convertQueryRecursive(query, false, null); RelNode rel = root.rel; + // Fix the correlation namespaces and de-duplicate the correlation variables. 
+ CorrelationUse correlationUse = getCorrelationUse(this, root.rel); + if (correlationUse != null) { + rel = correlationUse.r; + } + while (rel instanceof Project || rel instanceof Sort && ((Sort) rel).fetch == null @@ -5533,7 +5546,13 @@ public class SqlToRelConverter { call = (SqlCall) expr; query = Iterables.getOnlyElement(call.getOperandList()); root = convertQueryRecursive(query, false, null); - return RexSubQuery.scalar(root.rel); + rel = root.rel; + // Fix the correlation namespaces and de-duplicate the correlation variables. + correlationUse = getCorrelationUse(this, root.rel); + if (correlationUse != null) { + rel = correlationUse.r; + } + return RexSubQuery.scalar(rel); case ARRAY_QUERY_CONSTRUCTOR: call = (SqlCall) expr; diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java b/core/src/test/java/org/apache/calcite/test/JdbcTest.java index 5cf1177f7f..af4bf13ae8 100644 --- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java +++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java @@ -8270,6 +8270,75 @@ public class JdbcTest { .returns("EXPR$0=[1, 1.1]\n"); } + /** Test case for + * <a href="https://issues.apache.org/jira/browse/CALCITE-6032">[CALCITE-6032] + * Multilevel correlated query is failing in RelDecorrelator code path</a>.
*/ + @Test void testMultiLevelDecorrelation() throws Exception { + String hsqldbMemUrl = "jdbc:hsqldb:mem:."; + Connection baseConnection = DriverManager.getConnection(hsqldbMemUrl); + Statement baseStmt = baseConnection.createStatement(); + baseStmt.execute("create table invoice (inv_id integer, col1\n" + + "integer, inv_amt integer)"); + baseStmt.execute("create table item(item_id integer, item_amt\n" + + "integer, item_col1 integer, item_col2 integer, item_col3\n" + + "integer,item_col4 integer )"); + baseStmt.execute("INSERT INTO invoice VALUES (1, 1, 1)"); + baseStmt.execute("INSERT INTO invoice VALUES (2, 2, 2)"); + baseStmt.execute("INSERT INTO invoice VALUES (3, 3, 3)"); + baseStmt.execute("INSERT INTO item values (1, 1, 1, 1, 1, 1)"); + baseStmt.execute("INSERT INTO item values (2, 2, 2, 2, 2, 2)"); + baseStmt.close(); + baseConnection.commit(); + + Properties info = new Properties(); + info.put("model", + "inline:" + + "{\n" + + " version: '1.0',\n" + + " defaultSchema: 'BASEJDBC',\n" + + " schemas: [\n" + + " {\n" + + " type: 'jdbc',\n" + + " name: 'BASEJDBC',\n" + + " jdbcDriver: '" + jdbcDriver.class.getName() + "',\n" + + " jdbcUrl: '" + hsqldbMemUrl + "',\n" + + " jdbcCatalog: null,\n" + + " jdbcSchema: null\n" + + " }\n" + + " ]\n" + + "}"); + + Connection calciteConnection = + DriverManager.getConnection("jdbc:calcite:", info); + + String statement = "SELECT Sum(invoice.inv_amt * (\n" + + " SELECT max(mainrate.item_id + mainrate.item_amt)\n" + + " FROM item AS mainrate\n" + + " WHERE mainrate.item_col1 is not null\n" + + " AND mainrate.item_col2 is not null\n" + + " AND mainrate.item_col3 = invoice.col1\n" + + " AND mainrate.item_col4 = (\n" + + " SELECT max(cr.item_col4)\n" + + " FROM item AS cr\n" + + " WHERE cr.item_col3 = mainrate.item_col3\n" + + " AND cr.item_col1 =\n" + + "mainrate.item_col1\n" + + " AND cr.item_col2 =\n" + + "mainrate.item_col2 \n" + + " AND cr.item_col4 <=\n" + + "invoice.inv_id))) AS invamount,\n" + + "count(*) AS 
invcount\n" + + "FROM invoice\n" + + "WHERE invoice.inv_amt < 10 AND invoice.inv_amt > 0"; + ResultSet rs = calciteConnection.prepareStatement(statement).executeQuery(); + assert rs.next(); + assertEquals(rs.getInt(1), 10); + assertEquals(rs.getInt(2), 3); + assert !rs.next(); + rs.close(); + calciteConnection.close(); + } + /** Test case for * <a href="https://issues.apache.org/jira/browse/CALCITE-5414">[CALCITE-5414]</a> * Convert between standard Gregorian and proleptic Gregorian calendars for