This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new c49792f9c7 [CALCITE-6032] Multilevel correlated query is failing in RelDecorrelator code path
c49792f9c7 is described below

commit c49792f9c72159571f898c5fca1e26cba9870b07
Author: Hanumath Maduri <hanu....@gmail.com>
AuthorDate: Fri Jan 19 11:14:46 2024 -0800

    [CALCITE-6032] Multilevel correlated query is failing in RelDecorrelator code path
---
 .../java/org/apache/calcite/plan/RelOptUtil.java   | 20 ++++++-
 .../main/java/org/apache/calcite/rex/RexUtil.java  | 23 ++++++++
 .../apache/calcite/sql2rel/RelFieldTrimmer.java    | 20 ++++++-
 .../apache/calcite/sql2rel/SqlToRelConverter.java  | 29 +++++++--
 .../java/org/apache/calcite/test/JdbcTest.java     | 69 ++++++++++++++++++++++
 5 files changed, 154 insertions(+), 7 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
index e8c9a48809..5387acd8cf 100644
--- a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
+++ b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
@@ -271,7 +271,7 @@ public abstract class RelOptUtil {
   }
 
   /**
-   * Returns a set of variables used by a relational expression or its
+   * Returns the set of variables used by a relational expression or its
    * descendants.
    *
    * <p>The set may contain "duplicates" (variables with different ids that,
@@ -286,6 +286,24 @@ public abstract class RelOptUtil {
     return visitor.vuv.variables;
   }
 
+  /**
+   * Returns the set of variables used by the given list of sub-queries and its descendants.
+   *
+   * @param subQueries The sub-queries containing correlation variables
+   * @return A list of correlation identifiers found within the sub-queries.
+   *          The type of the [CorrelationId] parameter corresponds to
+   *          {@link org.apache.calcite.rex.RexCorrelVariable#id}.
+   */
+  public static Set<CorrelationId> getVariablesUsed(List<RexSubQuery> subQueries) {
+    // Internally this function calls getVariablesUsed on a RelNode to get all the
+    // correlated variables in that RelNode
+    Set<CorrelationId> correlationIds = new HashSet<>();
+    for (RexSubQuery subQ : subQueries) {
+      correlationIds.addAll(getVariablesUsed(subQ.rel));
+    }
+    return correlationIds;
+  }
+
   /** Finds which columns of a correlation variable are used within a
    * relational expression. */
   public static ImmutableBitSet correlationColumns(CorrelationId id,
diff --git a/core/src/main/java/org/apache/calcite/rex/RexUtil.java b/core/src/main/java/org/apache/calcite/rex/RexUtil.java
index f9390a5b3c..5d7f473a20 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexUtil.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexUtil.java
@@ -2854,6 +2854,29 @@ public class RexUtil {
     }
   }
 
+  /** Visitor that collects all the top level SubQueries {@link RexSubQuery}
+   *  in a projection list of a given {@link Project}.*/
+  public static class SubQueryCollector extends RexVisitorImpl<Void> {
+    private List<RexSubQuery> subQueries;
+    private SubQueryCollector() {
+      super(true);
+      this.subQueries = new ArrayList<>();
+    }
+
+    @Override public Void visitSubQuery(RexSubQuery subQuery) {
+      subQueries.add(subQuery);
+      return null;
+    }
+
+    public static List<RexSubQuery> collect(Project project) {
+      SubQueryCollector subQueryCollector = new SubQueryCollector();
+      for (RexNode node : project.getProjects()) {
+        node.accept(subQueryCollector);
+      }
+      return subQueryCollector.subQueries;
+    }
+  }
+
   /** Visitor that throws {@link org.apache.calcite.util.Util.FoundOne} if
    * applied to an expression that contains a {@link RexSubQuery}. */
   public static class SubQueryFinder extends RexVisitorImpl<Void> {
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
index bfa69d0a4d..a8d99126ea 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelFieldTrimmer.java
@@ -55,6 +55,7 @@ import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexPermuteInputsShuttle;
 import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexSubQuery;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.rex.RexVisitor;
 import org.apache.calcite.sql.SqlExplainFormat;
@@ -489,7 +490,24 @@ public class RelFieldTrimmer implements ReflectiveVisitor {
         ord.e.accept(inputFinder);
       }
     }
-    ImmutableBitSet inputFieldsUsed = inputFinder.build();
+
+    // Collect all the SubQueries in the projection list.
+    List<RexSubQuery> subQueries = RexUtil.SubQueryCollector.collect(project);
+    // Get all the correlationIds present in the SubQueries
+    Set<CorrelationId> correlationIds = RelOptUtil.getVariablesUsed(subQueries);
+    ImmutableBitSet requiredColumns = ImmutableBitSet.of();
+    if (correlationIds.size() > 0) {
+      assert correlationIds.size() == 1;
+      // Correlation columns are also needed by SubQueries, so add them to inputFieldsUsed.
+      requiredColumns = RelOptUtil.correlationColumns(correlationIds.iterator().next(), project);
+    }
+
+    ImmutableBitSet finderFields = inputFinder.build();
+
+    ImmutableBitSet inputFieldsUsed = ImmutableBitSet.builder()
+        .addAll(requiredColumns)
+        .addAll(finderFields)
+        .build();
 
     // Create input with trimmed columns.
     TrimResult trimResult =
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
index 0998c05403..2b79213838 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
@@ -5491,18 +5491,25 @@ public class SqlToRelConverter {
               builder.add(convertExpression(node));
             }
             final ImmutableList<RexNode> list = builder.build();
+            RelNode rel = root.rel;
+            // Fix the correlation namespaces and de-duplicate the correlation variables.
+            CorrelationUse correlationUse = getCorrelationUse(this, root.rel);
+            if (correlationUse != null) {
+              rel = correlationUse.r;
+            }
+
             switch (kind) {
             case IN:
-              return RexSubQuery.in(root.rel, list);
+              return RexSubQuery.in(rel, list);
             case NOT_IN:
               return rexBuilder.makeCall(SqlStdOperatorTable.NOT,
-                  RexSubQuery.in(root.rel, list));
+                  RexSubQuery.in(rel, list));
             case SOME:
-              return RexSubQuery.some(root.rel, list,
+              return RexSubQuery.some(rel, list,
                   (SqlQuantifyOperator) call.getOperator());
             case ALL:
               return rexBuilder.makeCall(SqlStdOperatorTable.NOT,
-                  RexSubQuery.some(root.rel, list,
+                  RexSubQuery.some(rel, list,
                       negate((SqlQuantifyOperator) call.getOperator())));
             default:
               throw new AssertionError(kind);
@@ -5515,6 +5522,12 @@ public class SqlToRelConverter {
           query = Iterables.getOnlyElement(call.getOperandList());
           root = convertQueryRecursive(query, false, null);
           RelNode rel = root.rel;
+          // Fix the correlation namespaces and de-duplicate the correlation variables.
+          CorrelationUse correlationUse = getCorrelationUse(this, root.rel);
+          if (correlationUse != null) {
+            rel = correlationUse.r;
+          }
+
           while (rel instanceof Project
               || rel instanceof Sort
               && ((Sort) rel).fetch == null
@@ -5533,7 +5546,13 @@ public class SqlToRelConverter {
           call = (SqlCall) expr;
           query = Iterables.getOnlyElement(call.getOperandList());
           root = convertQueryRecursive(query, false, null);
-          return RexSubQuery.scalar(root.rel);
+          rel = root.rel;
+          // Fix the correlation namespaces and de-duplicate the correlation variables.
+          correlationUse = getCorrelationUse(this, root.rel);
+          if (correlationUse != null) {
+            rel = correlationUse.r;
+          }
+          return RexSubQuery.scalar(rel);
 
         case ARRAY_QUERY_CONSTRUCTOR:
           call = (SqlCall) expr;
diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
index 5cf1177f7f..af4bf13ae8 100644
--- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java
+++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
@@ -8270,6 +8270,75 @@ public class JdbcTest {
         .returns("EXPR$0=[1, 1.1]\n");
   }
 
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6032">[CALCITE-6032]
+   * NullPointerException in Reldecorrelator for a Multi level correlated subquery</a>. */
+  @Test void testMultiLevelDecorrelation() throws Exception {
+    String hsqldbMemUrl = "jdbc:hsqldb:mem:.";
+    Connection baseConnection = DriverManager.getConnection(hsqldbMemUrl);
+    Statement baseStmt = baseConnection.createStatement();
+    baseStmt.execute("create table invoice (inv_id integer, col1\n"
+        + "integer, inv_amt integer)");
+    baseStmt.execute("create table item(item_id integer, item_amt\n"
+        + "integer, item_col1 integer, item_col2 integer, item_col3\n"
+        + "integer,item_col4 integer )");
+    baseStmt.execute("INSERT INTO invoice VALUES (1, 1, 1)");
+    baseStmt.execute("INSERT INTO invoice VALUES (2, 2, 2)");
+    baseStmt.execute("INSERT INTO invoice VALUES (3, 3, 3)");
+    baseStmt.execute("INSERT INTO item values (1, 1, 1, 1, 1, 1)");
+    baseStmt.execute("INSERT INTO item values (2, 2, 2, 2, 2, 2)");
+    baseStmt.close();
+    baseConnection.commit();
+
+    Properties info = new Properties();
+    info.put("model",
+        "inline:"
+            + "{\n"
+            + "  version: '1.0',\n"
+            + "  defaultSchema: 'BASEJDBC',\n"
+            + "  schemas: [\n"
+            + "     {\n"
+            + "       type: 'jdbc',\n"
+            + "       name: 'BASEJDBC',\n"
+            + "       jdbcDriver: '" + jdbcDriver.class.getName() + "',\n"
+            + "       jdbcUrl: '" + hsqldbMemUrl + "',\n"
+            + "       jdbcCatalog: null,\n"
+            + "       jdbcSchema: null\n"
+            + "     }\n"
+            + "  ]\n"
+            + "}");
+
+    Connection calciteConnection =
+        DriverManager.getConnection("jdbc:calcite:", info);
+
+    String statement = "SELECT Sum(invoice.inv_amt * (\n"
+        + "              SELECT max(mainrate.item_id + mainrate.item_amt)\n"
+        + "              FROM   item AS mainrate\n"
+        + "              WHERE  mainrate.item_col1 is not null\n"
+        + "              AND    mainrate.item_col2 is not null\n"
+        + "              AND    mainrate.item_col3 = invoice.col1\n"
+        + "              AND    mainrate.item_col4 = (\n"
+        + "                        SELECT max(cr.item_col4)\n"
+        + "                        FROM   item AS cr\n"
+        + "                        WHERE  cr.item_col3 = mainrate.item_col3\n"
+        + "                                AND    cr.item_col1 =\n"
+        + "mainrate.item_col1\n"
+        + "                                AND    cr.item_col2 =\n"
+        + "mainrate.item_col2 \n"
+        + "                                AND    cr.item_col4 <=\n"
+        + "invoice.inv_id))) AS invamount,\n"
+        + "count(*)       AS invcount\n"
+        + "FROM    invoice\n"
+        + "WHERE  invoice.inv_amt < 10 AND  invoice.inv_amt > 0";
+    ResultSet rs = calciteConnection.prepareStatement(statement).executeQuery();
+    assert rs.next();
+    assertEquals(rs.getInt(1), 10);
+    assertEquals(rs.getInt(2), 3);
+    assert !rs.next();
+    rs.close();
+    calciteConnection.close();
+  }
+
   /** Test case for
   * <a href="https://issues.apache.org/jira/browse/CALCITE-5414">[CALCITE-5414]</a>
    * Convert between standard Gregorian and proleptic Gregorian calendars for

Reply via email to