This is an automated email from the ASF dual-hosted git repository.

hyuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new e043640  [CALCITE-2004] Push join predicate down into inner relation for lateral join
e043640 is described below

commit e043640ef1c3eef36bab457fc98303a7627b5505
Author: Haisheng Yuan <h.y...@alibaba-inc.com>
AuthorDate: Fri Mar 8 11:33:27 2019 -0600

    [CALCITE-2004] Push join predicate down into inner relation for lateral join
    
    Before this patch, the join predicate was put on top of LogicalCorrelate,
    which is a wrong plan and may cause wrong results for a left outer
    lateral join.
---
 .../apache/calcite/sql2rel/SqlToRelConverter.java  |  49 ++++++-
 .../org/apache/calcite/test/RelOptRulesTest.java   |   2 +-
 .../apache/calcite/test/SqlToRelConverterTest.java |  85 ++++++++++++
 .../org/apache/calcite/test/TableFunctionTest.java |  14 ++
 .../org/apache/calcite/test/RelOptRulesTest.xml    |  36 ++---
 .../apache/calcite/test/SqlToRelConverterTest.xml  | 145 ++++++++++++++++++++-
 6 files changed, 305 insertions(+), 26 deletions(-)

diff --git 
a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java 
b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
index f462804..698a028 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java
@@ -188,6 +188,7 @@ import java.math.BigDecimal;
 import java.util.AbstractList;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.BitSet;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Deque;
@@ -2425,6 +2426,30 @@ public class SqlToRelConverter {
     }
   }
 
+  /** Shuttle that replaces each outer {@link RexInputRef} with a
+   * {@link RexFieldAccess} on the correlation variable, and shifts
+   * each inner {@link RexInputRef} down by the outer field count,
+   * when rewriting a lateral join condition. */
+  private static class RexAccessShuttle extends RexShuttle {
+    private final RexBuilder builder;
+    private final RexCorrelVariable rexCorrel;
+    private final BitSet varCols = new BitSet();
+
+    RexAccessShuttle(RexBuilder builder, RexCorrelVariable rexCorrel) {
+      this.builder = builder;
+      this.rexCorrel = rexCorrel;
+    }
+
+    @Override public RexNode visitInputRef(RexInputRef input) {
+      int i = input.getIndex() - rexCorrel.getType().getFieldCount();
+      if (i < 0) {
+        varCols.set(input.getIndex());
+        return builder.makeFieldAccess(rexCorrel, input.getIndex());
+      }
+      return builder.makeInputRef(input.getType(), i);
+    }
+  }
+
   protected RelNode createJoin(
       Blackboard bb,
       RelNode leftRel,
@@ -2435,13 +2460,29 @@ public class SqlToRelConverter {
 
     final CorrelationUse p = getCorrelationUse(bb, rightRel);
     if (p != null) {
-      LogicalCorrelate corr = LogicalCorrelate.create(leftRel, p.r,
-          p.id, p.requiredColumns, SemiJoinType.of(joinType));
+      RelNode innerRel = p.r;
+      ImmutableBitSet requiredCols = p.requiredColumns;
+
       if (!joinCond.isAlwaysTrue()) {
         final RelFactories.FilterFactory factory =
             RelFactories.DEFAULT_FILTER_FACTORY;
-        return factory.createFilter(corr, joinCond);
-      }
+        final RexCorrelVariable rexCorrel =
+            (RexCorrelVariable) rexBuilder.makeCorrel(
+                leftRel.getRowType(), p.id);
+        final RexAccessShuttle shuttle =
+            new RexAccessShuttle(rexBuilder, rexCorrel);
+
+        // Replace outer RexInputRef with RexFieldAccess,
+        // and push lateral join predicate into inner child
+        final RexNode newCond = joinCond.accept(shuttle);
+        innerRel = factory.createFilter(p.r, newCond);
+        requiredCols = ImmutableBitSet
+            .fromBitSet(shuttle.varCols)
+            .union(p.requiredColumns);
+      }
+
+      LogicalCorrelate corr = LogicalCorrelate.create(leftRel, innerRel,
+          p.id, requiredCols, SemiJoinType.of(joinType));
       return corr;
     }
 
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java 
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index d033818..600f4b7 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -3473,7 +3473,7 @@ public class RelOptRulesTest extends RelOptTestBase {
         + "  select n2.SAL\n"
         + "  from EMPNULLABLES_20 n2\n"
         + "  where n1.SAL = n2.SAL or n1.SAL = 4)";
-    sql(sql).withDecorrelation(true).with(program).check();
+    sql(sql).withDecorrelation(true).with(program).checkUnchanged();
   }
 
   /** Test case for
diff --git 
a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java 
b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
index 9e21272..499df02 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlToRelConverterTest.java
@@ -3120,6 +3120,91 @@ public class SqlToRelConverterTest extends 
SqlToRelTestBase {
   }
 
   /**
+   * Tests left join lateral with using
+   */
+  @Test public void testLeftJoinLateral1() {
+    final String sql = "select * from (values 4) as t(c)\n"
+        + " left join lateral\n"
+        + " (select c,a*c from (values 2) as s(a)) as r(d,c)\n"
+        + " using(c)";
+    sql(sql).ok();
+  }
+
+  /**
+   * Tests left join lateral with natural join
+   */
+  @Test public void testLeftJoinLateral2() {
+    final String sql = "select * from (values 4) as t(c)\n"
+        + " natural left join lateral\n"
+        + " (select c,a*c from (values 2) as s(a)) as r(d,c)";
+    sql(sql).ok();
+  }
+
+  /**
+   * Tests left join lateral with on condition
+   */
+  @Test public void testLeftJoinLateral3() {
+    final String sql = "select * from (values 4) as t(c)\n"
+        + " left join lateral\n"
+        + " (select c,a*c from (values 2) as s(a)) as r(d,c)\n"
+        + " on t.c=r.c";
+    sql(sql).ok();
+  }
+
+  /**
+   * Tests left join lateral with multiple columns from outer
+   */
+  @Test public void testLeftJoinLateral4() {
+    final String sql = "select * from (values (4,5)) as t(c,d)\n"
+        + " left join lateral\n"
+        + " (select c,a*c from (values 2) as s(a)) as r(d,c)\n"
+        + " on t.c+t.d=r.c";
+    sql(sql).ok();
+  }
+
+  /**
+   * Tests left join lateral with correlate variable coming
+   * from one level up join scope
+   */
+  @Test public void testLeftJoinLateral5() {
+    final String sql = "select * from (values 4) as t (c)\n"
+        + "left join lateral\n"
+        + "  (select f1+b1 from (values 2) as foo(f1)\n"
+        + "    join\n"
+        + "  (select c+1 from (values 3)) as bar(b1)\n"
+        + "  on f1=b1)\n"
+        + "as r(n) on c=n";
+    sql(sql).ok();
+  }
+
+  /**
+   * Tests cross join lateral with multiple columns from outer
+   */
+  @Test public void testCrossJoinLateral1() {
+    final String sql = "select * from (values (4,5)) as t(c,d)\n"
+        + " cross join lateral\n"
+        + " (select c,a*c as f from (values 2) as s(a)\n"
+        + " where c+d=a*c)";
+    sql(sql).ok();
+  }
+
+  /**
+   * Tests cross join lateral with correlate variable coming
+   * from one level up join scope
+   */
+  @Test public void testCrossJoinLateral2() {
+    final String sql = "select * from (values 4) as t (c)\n"
+        + "cross join lateral\n"
+        + "(select * from (\n"
+        + "  select f1+b1 from (values 2) as foo(f1)\n"
+        + "    join\n"
+        + "  (select c+1 from (values 3)) as bar(b1)\n"
+        + "  on f1=b1\n"
+        + ") as r(n) where c=n)";
+    sql(sql).ok();
+  }
+
+  /**
    * Visitor that checks that every {@link RelNode} in a tree is valid.
    *
    * @see RelNode#isValid(Litmus, RelNode.Context)
diff --git a/core/src/test/java/org/apache/calcite/test/TableFunctionTest.java 
b/core/src/test/java/org/apache/calcite/test/TableFunctionTest.java
index cddffe3..98cc725 100644
--- a/core/src/test/java/org/apache/calcite/test/TableFunctionTest.java
+++ b/core/src/test/java/org/apache/calcite/test/TableFunctionTest.java
@@ -432,6 +432,20 @@ public class TableFunctionTest {
   }
 
   /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-2004">[CALCITE-2004]
+   * Wrong plan generated for left outer apply with table function</a>. */
+  @Test public void testLeftOuterApply() {
+    final String sql = "select *\n"
+        + "from (values 4) as t (c)\n"
+        + "left join lateral table(\"s\".\"fibonacci2\"(c)) as R(n) on c=n";
+    with()
+        .with(CalciteConnectionProperty.CONFORMANCE,
+            SqlConformanceEnum.LENIENT)
+        .query(sql)
+        .returnsUnordered("C=4; N=null");
+  }
+
+  /** Test case for
   * <a href="https://issues.apache.org/jira/browse/CALCITE-2382">[CALCITE-2382]
    * Sub-query lateral joined to table function</a>. */
   @Test public void testInlineViewLateralTableFunction() throws SQLException {
diff --git 
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml 
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index adfd110..ffa5377 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -2764,9 +2764,9 @@ IN (select e.deptno from sales.emp e where e.deptno = 
d.deptno or e.deptno = 4)]
         <Resource name="planBefore">
             <![CDATA[
 LogicalProject(DEPTNO=[$7])
-  LogicalFilter(condition=[=($7, $9)])
-    LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{7}])
-      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+  LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{7}])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalFilter(condition=[=($cor0.DEPTNO, $0)])
       LogicalAggregate(group=[{0}])
         LogicalProject(DEPTNO=[$7])
           LogicalFilter(condition=[OR(=($7, $cor0.DEPTNO), =($7, 4))])
@@ -2786,31 +2786,31 @@ where n1.SAL IN (
         <Resource name="planBefore">
             <![CDATA[
 LogicalProject(SAL=[$5])
-  LogicalJoin(condition=[AND(=($5, $9), =($5, $8))], joinType=[inner])
+  LogicalJoin(condition=[=($5, $9)], joinType=[inner])
     LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], 
SAL=[$5], COMM=[$6], SLACKER=[$8])
       LogicalFilter(condition=[AND(=($7, 20), >($5, 1000))])
         LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
-    LogicalAggregate(group=[{0, 1}])
-      LogicalProject(SAL=[$5], SAL0=[$8])
-        LogicalJoin(condition=[OR(=($8, $5), =($8, 4))], joinType=[inner])
-          LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], 
HIREDATE=[$4], SAL=[$5], COMM=[$6], SLACKER=[$8])
-            LogicalFilter(condition=[AND(=($7, 20), >($5, 1000))])
-              LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
-          LogicalAggregate(group=[{0}])
-            LogicalProject(SAL=[$5])
+    LogicalFilter(condition=[=($1, $0)])
+      LogicalAggregate(group=[{0, 1}])
+        LogicalProject(SAL=[$5], SAL0=[$8])
+          LogicalJoin(condition=[OR(=($8, $5), =($8, 4))], joinType=[inner])
+            LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], 
HIREDATE=[$4], SAL=[$5], COMM=[$6], SLACKER=[$8])
               LogicalFilter(condition=[AND(=($7, 20), >($5, 1000))])
                 LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
+            LogicalAggregate(group=[{0}])
+              LogicalProject(SAL=[$5])
+                LogicalFilter(condition=[AND(=($7, 20), >($5, 1000))])
+                  LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
 ]]>
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(SAL=[$5])
-  LogicalJoin(condition=[AND(=($5, $9), =($5, $8))], joinType=[inner])
-    LogicalFilter(condition=[OR(IS NOT NULL($5), =($5, 4))])
-      LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], 
HIREDATE=[$4], SAL=[$5], COMM=[$6], SLACKER=[$8])
-        LogicalFilter(condition=[AND(=($7, 20), >($5, 1000))])
-          LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
-    LogicalFilter(condition=[AND(OR(IS NOT NULL($0), =($0, 4)), OR(=($0, $1), 
=($0, 4)), OR(IS NOT NULL($1), =($1, 4)))])
+  LogicalJoin(condition=[=($5, $9)], joinType=[inner])
+    LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], 
SAL=[$5], COMM=[$6], SLACKER=[$8])
+      LogicalFilter(condition=[AND(=($7, 20), >($5, 1000))])
+        LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
+    LogicalFilter(condition=[=($1, $0)])
       LogicalAggregate(group=[{0, 1}])
         LogicalProject(SAL=[$5], SAL0=[$8])
           LogicalJoin(condition=[OR(=($8, $5), =($8, 4))], joinType=[inner])
diff --git 
a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml 
b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
index 8496053..3d343c6 100644
--- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
@@ -277,9 +277,9 @@ for system_time as of orders.rowtime on orders.productid = 
products_temporal.pro
             <![CDATA[
 LogicalDelta
   LogicalProject(ROWTIME=[$0], PRODUCTID=[$1], ORDERID=[$2], PRODUCTID0=[$3], 
NAME=[$4], SUPPLIERID=[$5], SYS_START=[$6], SYS_END=[$7])
-    LogicalFilter(condition=[=($1, $3)])
-      LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{0}])
-        LogicalTableScan(table=[[CATALOG, SALES, ORDERS]])
+    LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{0, 1}])
+      LogicalTableScan(table=[[CATALOG, SALES, ORDERS]])
+      LogicalFilter(condition=[=($cor0.PRODUCTID, $0)])
         LogicalSnapshot(period=[$cor0.ROWTIME])
           LogicalTableScan(table=[[CATALOG, SALES, PRODUCTS_TEMPORAL]])
 ]]>
@@ -5726,4 +5726,143 @@ LogicalSort(sort0=[$0], dir0=[ASC])
 ]]>
         </Resource>
     </TestCase>
+    <TestCase name="testLeftJoinLateral1">
+        <Resource name="sql">
+            <![CDATA[select * from (values 4) as t (c)
+            left join lateral
+            (select c,a*c from (values 2) as s(a)) as r(d,c)
+            using(c)]]>
+        </Resource>
+        <Resource name="plan">
+            <![CDATA[
+LogicalProject(C=[$0], D=[$1])
+  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])
+    LogicalValues(tuples=[[{ 4 }]])
+    LogicalFilter(condition=[=($cor0.C, $1)])
+      LogicalProject(C=[$cor0.C], EXPR$1=[*($0, $cor0.C)])
+        LogicalValues(tuples=[[{ 2 }]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testLeftJoinLateral2">
+        <Resource name="sql">
+            <![CDATA[select * from (values 4) as t (c)
+            natural left join lateral
+            (select c,a*c from (values 2) as s(a)) as r(d,c)]]>
+        </Resource>
+        <Resource name="plan">
+            <![CDATA[
+LogicalProject(C=[$0], D=[$1])
+  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])
+    LogicalValues(tuples=[[{ 4 }]])
+    LogicalFilter(condition=[=($cor0.C, $1)])
+      LogicalProject(C=[$cor0.C], EXPR$1=[*($0, $cor0.C)])
+        LogicalValues(tuples=[[{ 2 }]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testLeftJoinLateral3">
+        <Resource name="sql">
+            <![CDATA[select * from (values 4) as t (c)
+            left join lateral
+            (select c,a*c from (values 2) as s(a)) as r(d,c)
+            on t.c=r.c]]>
+        </Resource>
+        <Resource name="plan">
+            <![CDATA[
+LogicalProject(C=[$0], D=[$1], C0=[$2])
+  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])
+    LogicalValues(tuples=[[{ 4 }]])
+    LogicalFilter(condition=[=($cor0.C, $1)])
+      LogicalProject(C=[$cor0.C], EXPR$1=[*($0, $cor0.C)])
+        LogicalValues(tuples=[[{ 2 }]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testLeftJoinLateral4">
+        <Resource name="sql">
+            <![CDATA[select * from (values (4,5)) as t(c,d)
+            left join lateral
+            (select c,a*c from (values 2) as s(a)) as r(d,c)
+            on t.c+t.d=r.c]]>
+        </Resource>
+        <Resource name="plan">
+            <![CDATA[
+LogicalProject(C=[$0], D=[$1], D0=[$2], C0=[$3])
+  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0, 
1}])
+    LogicalValues(tuples=[[{ 4, 5 }]])
+    LogicalFilter(condition=[=(+($cor0.C, $cor0.D), $1)])
+      LogicalProject(C=[$cor0.C], EXPR$1=[*($0, $cor0.C)])
+        LogicalValues(tuples=[[{ 2 }]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testLeftJoinLateral5">
+        <Resource name="sql">
+            <![CDATA[select * from (values 4) as t (c)
+        left join lateral
+          (select f1+b1 from (values 2) as foo(f1)
+            join
+          (select c+1 from (values 3)) as bar(b1)
+          on f1=b1)
+        as r(n) on c=n]]>
+        </Resource>
+        <Resource name="plan">
+            <![CDATA[
+LogicalProject(C=[$0], N=[$1])
+  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])
+    LogicalValues(tuples=[[{ 4 }]])
+    LogicalFilter(condition=[=($cor0.C, $0)])
+      LogicalProject(EXPR$0=[+($0, $1)])
+        LogicalJoin(condition=[=($0, $1)], joinType=[inner])
+          LogicalValues(tuples=[[{ 2 }]])
+          LogicalProject(EXPR$0=[+($cor0.C, 1)])
+            LogicalValues(tuples=[[{ 3 }]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testCrossJoinLateral1">
+        <Resource name="sql">
+            <![CDATA[select * from (values (4,5)) as t(c,d)
+         cross join lateral
+         (select c,a*c as f from (values 2) as s(a)
+         where c+d=a*c)]]>
+        </Resource>
+        <Resource name="plan">
+            <![CDATA[
+LogicalProject(C=[$0], D=[$1], C0=[$2], F=[$3])
+  LogicalCorrelate(correlation=[$cor3], joinType=[inner], requiredColumns=[{0, 
1}])
+    LogicalValues(tuples=[[{ 4, 5 }]])
+    LogicalProject(C=[$cor3.C], F=[*($0, $cor3.C)])
+      LogicalFilter(condition=[=(+($cor3.C, $cor3.D), *($0, $cor3.C))])
+        LogicalValues(tuples=[[{ 2 }]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testCrossJoinLateral2">
+        <Resource name="sql">
+            <![CDATA[select * from (values 4) as t (c)
+        cross join lateral
+        (select * from (
+          select f1+b1 from (values 2) as foo(f1)
+            join
+          (select c+1 from (values 3)) as bar(b1)
+          on f1=b1
+        ) as r(n) where c=n)]]>
+        </Resource>
+        <Resource name="plan">
+            <![CDATA[
+LogicalProject(C=[$0], N=[$1])
+  LogicalCorrelate(correlation=[$cor1], joinType=[inner], 
requiredColumns=[{0}])
+    LogicalValues(tuples=[[{ 4 }]])
+    LogicalProject(N=[$0])
+      LogicalFilter(condition=[=($cor1.C, $0)])
+        LogicalProject(EXPR$0=[+($0, $1)])
+          LogicalJoin(condition=[=($0, $1)], joinType=[inner])
+            LogicalValues(tuples=[[{ 2 }]])
+            LogicalProject(EXPR$0=[+($cor1.C, 1)])
+              LogicalValues(tuples=[[{ 3 }]])
+]]>
+        </Resource>
+    </TestCase>
 </Root>

Reply via email to