This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 420e6878c68 [SPARK-43780][SQL] Support correlated references in join 
predicates for scalar and lateral subqueries
420e6878c68 is described below

commit 420e6878c68f9ff68171cc6d43ca95d69c49eac4
Author: Andrey Gubichev <andrey.gubic...@databricks.com>
AuthorDate: Tue Aug 15 09:52:12 2023 +0800

    [SPARK-43780][SQL] Support correlated references in join predicates for 
scalar and lateral subqueries
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for subqueries that involve joins with correlated 
references in join predicates, e.g.
    
    ```
    select * from t0 join lateral (select * from t1 join t2 on t1a = t2a and 
t1a = t0a);
    ```
    
    (full example in https://issues.apache.org/jira/browse/SPARK-43780)
    
    Currently we only handle scalar and lateral subqueries.
    
    ### Why are the changes needed?
    
    This is valid SQL that is not yet supported by Spark SQL.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, previously unsupported queries become supported.
    
    ### How was this patch tested?
    
    Query and unit tests
    
    Closes #41301 from agubichev/spark-43780-corr-predicate.
    
    Authored-by: Andrey Gubichev <andrey.gubic...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      |   1 +
 .../catalyst/optimizer/DecorrelateInnerQuery.scala |  92 ++++++++++++--
 .../org/apache/spark/sql/internal/SQLConf.scala    |  10 ++
 .../optimizer/DecorrelateInnerQuerySuite.scala     | 133 ++++++++++++++++++++-
 .../analyzer-results/join-lateral.sql.out          |  66 ++++++++++
 .../scalar-subquery-predicate.sql.out              |  81 +++++++++++++
 .../scalar-subquery/scalar-subquery-set-op.sql.out |  80 +++++++++++++
 .../resources/sql-tests/inputs/join-lateral.sql    |   5 +
 .../scalar-subquery/scalar-subquery-predicate.sql  |  20 ++++
 .../scalar-subquery/scalar-subquery-set-op.sql     |  20 ++++
 .../sql-tests/results/join-lateral.sql.out         |  39 ++++++
 .../scalar-subquery-predicate.sql.out              |  37 ++++++
 .../scalar-subquery/scalar-subquery-set-op.sql.out |  32 +++++
 13 files changed, 605 insertions(+), 11 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 48c38a9bd4c..c7346809f3f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -1173,6 +1173,7 @@ trait CheckAnalysis extends PredicateHelper with 
LookupCatalog with QueryErrorsB
     def canHostOuter(plan: LogicalPlan): Boolean = plan match {
       case _: Filter => true
       case _: Project => usingDecorrelateInnerQueryFramework
+      case _: Join => usingDecorrelateInnerQueryFramework
       case _ => false
     }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
index 86fa78e96a5..a3e264579f4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
@@ -804,11 +804,73 @@ object DecorrelateInnerQuery extends PredicateHelper {
             (d.copy(child = newChild), joinCond, outerReferenceMap)
 
           case j @ Join(left, right, joinType, condition, _) =>
-            val outerReferences = collectOuterReferences(j.expressions)
-            // Join condition containing outer references is not supported.
-            assert(outerReferences.isEmpty, s"Correlated column is not allowed 
in join: $j")
-            val newOuterReferences = parentOuterReferences ++ outerReferences
-            val shouldPushToLeft = joinType match {
+            // Given 'condition', computes the tuple of
+            // (correlated, uncorrelated, equalityCond, predicates, 
equivalences).
+            // 'correlated' and 'uncorrelated' are the conjuncts with (resp. 
without)
+            // outer (correlated) references. Furthermore, correlated 
conjuncts are split
+            // into 'equalityCond' (those that are equalities) and all the rest 
('predicates').
+            // 'equivalences' track equivalent attributes given 'equalityCond'.
+            // The split is only performed if 'shouldDecorrelatePredicates' is 
true.
+            // The input parameter 'isInnerJoin' is set to true for INNER 
joins and helps
+            // determine whether some predicates can be lifted up from the 
join (this is only
+            // valid for inner joins).
+            // Example: For a 'condition' A = outer(X) AND B > outer(Y) AND C 
= D, the output
+            // would be:
+            // correlated = (A = outer(X), B > outer(Y))
+            // uncorrelated = (C = D)
+            // equalityCond = (A = outer(X))
+            // predicates = (B > outer(Y))
+            // equivalences: (A -> outer(X))
+            def splitCorrelatedPredicate(
+                condition: Option[Expression],
+                isInnerJoin: Boolean,
+                shouldDecorrelatePredicates: Boolean):
+            (Seq[Expression], Seq[Expression], Seq[Expression],
+              Seq[Expression], AttributeMap[Attribute]) = {
+              // Similar to Filters above, we split the join condition (if 
present) into correlated
+              // and uncorrelated predicates, and separately handle joins 
under set and aggregation
+              // operations.
+              if (shouldDecorrelatePredicates) {
+                val conditions =
+                  if (condition.isDefined) 
splitConjunctivePredicates(condition.get)
+                  else Seq.empty[Expression]
+                val (correlated, uncorrelated) = 
conditions.partition(containsOuter)
+                var equivalences =
+                  if (underSetOp) AttributeMap.empty[Attribute]
+                  else collectEquivalentOuterReferences(correlated)
+                var (equalityCond, predicates) =
+                  if (underSetOp) (Seq.empty[Expression], correlated)
+                  else correlated.partition(canPullUpOverAgg)
+                // Fully preserve the join predicate for non-inner joins.
+                if (!isInnerJoin) {
+                  predicates = correlated
+                  equalityCond = Seq.empty[Expression]
+                  equivalences = AttributeMap.empty[Attribute]
+                }
+                (correlated, uncorrelated, equalityCond, predicates, 
equivalences)
+              } else {
+                (Seq.empty[Expression],
+                  if (condition.isEmpty) Seq.empty[Expression] else 
Seq(condition.get),
+                  Seq.empty[Expression],
+                  Seq.empty[Expression],
+                  AttributeMap.empty[Attribute])
+              }
+            }
+
+            val shouldDecorrelatePredicates =
+              SQLConf.get.getConf(SQLConf.DECORRELATE_JOIN_PREDICATE_ENABLED)
+            if (!shouldDecorrelatePredicates) {
+              val outerReferences = collectOuterReferences(j.expressions)
+              // Join condition containing outer references is not supported.
+              assert(outerReferences.isEmpty, s"Correlated column is not 
allowed in join: $j")
+            }
+            val (correlated, uncorrelated, equalityCond, predicates, 
equivalences) =
+              splitCorrelatedPredicate(condition, joinType == Inner, 
shouldDecorrelatePredicates)
+            val outerReferences = collectOuterReferences(j.expressions) ++
+              collectOuterReferences(predicates)
+            val newOuterReferences =
+              parentOuterReferences ++ outerReferences -- equivalences.keySet
+            var shouldPushToLeft = joinType match {
               case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true
               case _ => hasOuterReferences(left)
             }
@@ -816,6 +878,14 @@ object DecorrelateInnerQuery extends PredicateHelper {
               case RightOuter | FullOuter => true
               case _ => hasOuterReferences(right)
             }
+            if (shouldDecorrelatePredicates && !shouldPushToLeft && 
!shouldPushToRight
+              && !predicates.isEmpty) {
+              // Neither left nor right children of the join have 
correlations, but the join
+              // predicate does, and the correlations can not be replaced via 
equivalences.
+              // Introduce a domain join on the left side of the join
+              // (chosen arbitrarily) to provide values for the correlated 
attribute reference.
+              shouldPushToLeft = true;
+            }
             val (newLeft, leftJoinCond, leftOuterReferenceMap) = if 
(shouldPushToLeft) {
               decorrelate(left, newOuterReferences, aggregated, underSetOp)
             } else {
@@ -826,8 +896,13 @@ object DecorrelateInnerQuery extends PredicateHelper {
             } else {
               (right, Nil, AttributeMap.empty[Attribute])
             }
-            val newOuterReferenceMap = leftOuterReferenceMap ++ 
rightOuterReferenceMap
-            val newJoinCond = leftJoinCond ++ rightJoinCond
+            val newOuterReferenceMap = leftOuterReferenceMap ++ 
rightOuterReferenceMap ++
+              equivalences
+            val newCorrelated =
+              if (shouldDecorrelatePredicates) {
+                replaceOuterReferences(correlated, newOuterReferenceMap)
+              } else Seq.empty[Expression]
+            val newJoinCond = leftJoinCond ++ rightJoinCond ++ equalityCond
             // If we push the dependent join to both sides, we can augment the 
join condition
             // such that both sides are matched on the domain attributes. For 
example,
             // - Left Map: {outer(c1) = c1}
@@ -836,7 +911,8 @@ object DecorrelateInnerQuery extends PredicateHelper {
             val augmentedConditions = leftOuterReferenceMap.flatMap {
               case (outer, inner) => 
rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _))
             }
-            val newCondition = (condition ++ 
augmentedConditions).reduceOption(And)
+            val newCondition = (newCorrelated ++ uncorrelated
+              ++ augmentedConditions).reduceOption(And)
             val newJoin = j.copy(left = newLeft, right = newRight, condition = 
newCondition)
             (newJoin, newJoinCond, newOuterReferenceMap)
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e4f335a9a08..ced3f3458c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4370,6 +4370,16 @@ object SQLConf {
       .checkValue(_ >= 0, "The threshold of cached local relations must not be 
negative")
       .createWithDefault(64 * 1024 * 1024)
 
+  val DECORRELATE_JOIN_PREDICATE_ENABLED =
+    buildConf("spark.sql.optimizer.decorrelateJoinPredicate.enabled")
+      .internal()
+      .doc("Decorrelate scalar and lateral subqueries with correlated 
references in join " +
+        "predicates. This configuration is only effective when " +
+        s"'${DECORRELATE_INNER_QUERY_ENABLED.key}' is true.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
+
   /**
    * Holds information about keys that have been deprecated.
    *
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
index 304f7de4c6a..21ac8849fe2 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
@@ -35,10 +35,13 @@ class DecorrelateInnerQuerySuite extends PlanTest {
   val a3 = AttributeReference("a3", IntegerType)()
   val b3 = AttributeReference("b3", IntegerType)()
   val c3 = AttributeReference("c3", IntegerType)()
+  val a4 = AttributeReference("a4", IntegerType)()
+  val b4 = AttributeReference("b4", IntegerType)()
   val t0 = OneRowRelation()
   val testRelation = LocalRelation(a, b, c)
   val testRelation2 = LocalRelation(x, y, z)
   val testRelation3 = LocalRelation(a3, b3, c3)
+  val testRelation4 = LocalRelation(a4, b4)
 
   private def hasOuterReferences(plan: LogicalPlan): Boolean = {
     plan.exists(_.expressions.exists(SubExprUtils.containsOuter))
@@ -198,12 +201,15 @@ class DecorrelateInnerQuerySuite extends PlanTest {
     val innerPlan =
       Join(
         testRelation.as("t1"),
-        Filter(OuterReference(y) === 3, testRelation),
+        Filter(OuterReference(y) === b3, testRelation3),
         Inner,
         Some(OuterReference(x) === a),
         JoinHint.NONE)
-    val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, 
outerPlan.select()) }
-    assert(error.getMessage.contains("Correlated column is not allowed in 
join"))
+    val correctAnswer =
+      Join(
+        testRelation.as("t1"), testRelation3,
+        Inner, Some(a === a), JoinHint.NONE)
+    check(innerPlan, outerPlan, correctAnswer, Seq(b3 === y, x === a))
   }
 
   test("correlated values in project") {
@@ -454,4 +460,125 @@ class DecorrelateInnerQuerySuite extends PlanTest {
             DomainJoin(Seq(x), testRelation))))
     check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x))
   }
+
+  test("SPARK-43780: aggregation in subquery with correlated equi-join") {
+    // Join in the subquery is on equi-predicates, so all the correlated 
references can be
+    // substituted by equivalent ones from the outer query, and domain join is 
not needed.
+    val outerPlan = testRelation
+    val innerPlan =
+      Aggregate(
+        Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
+        Project(Seq(x, y, a3, b3),
+          Join(testRelation2, testRelation3, Inner,
+            Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))
+
+    val correctAnswer =
+      Aggregate(
+        Seq(y), Seq(Alias(count(Literal(1)), "a")(), y),
+        Project(Seq(x, y, a3, b3),
+          Join(testRelation2, testRelation3, Inner, Some(And(y === y, x === 
a3)), JoinHint.NONE)))
+    check(innerPlan, outerPlan, correctAnswer, Seq(y === a))
+  }
+
+  test("SPARK-43780: aggregation in subquery with correlated non-equi-join") {
+    // Join in the subquery is on non-equi-predicate, so we introduce a 
DomainJoin.
+    val outerPlan = testRelation
+    val innerPlan =
+      Aggregate(
+        Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
+        Project(Seq(x, y, a3, b3),
+          Join(testRelation2, testRelation3, Inner,
+            Some(And(x === a3, y > OuterReference(a))), JoinHint.NONE)))
+    val correctAnswer =
+      Aggregate(
+        Seq(a), Seq(Alias(count(Literal(1)), "a")(), a),
+        Project(Seq(x, y, a3, b3, a),
+          Join(
+            DomainJoin(Seq(a), testRelation2),
+            testRelation3, Inner, Some(And(x === a3, y > a)), JoinHint.NONE)))
+    check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
+  }
+
+  test("SPARK-43780: aggregation in subquery with correlated left join") {
+    // Join in the subquery is on equi-predicates, so all the correlated 
references can be
+    // substituted by equivalent ones from the outer query, and domain join is 
not needed.
+    val outerPlan = testRelation
+    val innerPlan =
+      Aggregate(
+        Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
+        Project(Seq(x, y, a3, b3),
+          Join(testRelation2, testRelation3, LeftOuter,
+            Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))
+
+    val correctAnswer =
+      Aggregate(
+        Seq(a), Seq(Alias(count(Literal(1)), "a")(), a),
+        Project(Seq(x, y, a3, b3, a),
+          Join(DomainJoin(Seq(a), testRelation2), testRelation3, LeftOuter,
+            Some(And(y === a, x === a3)), JoinHint.NONE)))
+    check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
+  }
+
+  test("SPARK-43780: aggregation in subquery with correlated left join, " +
+    "correlation over right side") {
+    // Same as above, but the join predicate connects the outer reference and 
the column from the
+    // right (optional) side of the left join. Domain join is still not needed.
+    val outerPlan = testRelation
+    val innerPlan =
+      Aggregate(
+        Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
+        Project(Seq(x, y, a3, b3),
+          Join(testRelation2, testRelation3, LeftOuter,
+            Some(And(x === a3, b3 === OuterReference(b))), JoinHint.NONE)))
+
+    val correctAnswer =
+      Aggregate(
+        Seq(b), Seq(Alias(count(Literal(1)), "a")(), b),
+        Project(Seq(x, y, a3, b3, b),
+          Join(DomainJoin(Seq(b), testRelation2), testRelation3, LeftOuter,
+            Some(And(b === b3, x === a3)), JoinHint.NONE)))
+    check(innerPlan, outerPlan, correctAnswer, Seq(b <=> b))
+  }
+
+  test("SPARK-43780: correlated left join preserves the join predicates") {
+    // Left outer join preserves both predicates after being decorrelated.
+    val outerPlan = testRelation
+    val innerPlan =
+      Filter(
+        IsNotNull(c3),
+        Project(Seq(x, y, a3, b3, c3),
+          Join(testRelation2, testRelation3, LeftOuter,
+            Some(And(x === a3, b3 === OuterReference(b))), JoinHint.NONE)))
+
+    val correctAnswer =
+      Filter(
+        IsNotNull(c3),
+        Project(Seq(x, y, a3, b3, c3, b),
+          Join(DomainJoin(Seq(b), testRelation2), testRelation3, LeftOuter,
+            Some(And(x === a3, b === b3)), JoinHint.NONE)))
+    check(innerPlan, outerPlan, correctAnswer, Seq(b <=> b))
+  }
+
+  test("SPARK-43780: union all in subquery with correlated join") {
+    val outerPlan = testRelation
+    val innerPlan =
+      Union(
+        Seq(Project(Seq(x, b3),
+          Join(testRelation2, testRelation3, Inner,
+            Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)),
+          Project(Seq(a4, b4),
+            testRelation4)))
+    val correctAnswer =
+      Union(
+        Seq(Project(Seq(x, b3, a),
+          Project(Seq(x, b3, a),
+            Join(
+              DomainJoin(Seq(a), testRelation2),
+              testRelation3, Inner,
+              Some(And(x === a3, y === a)), JoinHint.NONE))),
+          Project(Seq(a4, b4, a),
+            DomainJoin(Seq(a),
+              Project(Seq(a4, b4), testRelation4)))))
+    check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
+  }
 }
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out
index 5225996c16b..2d1eebc65c6 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out
@@ -795,6 +795,72 @@ Project [c1#x, c2#x]
             +- LocalRelation [col1#x, col2#x]
 
 
+-- !query
+SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 = t4.c1 AND 
t2.c1 = t1.c1)
+-- !query analysis
+Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
++- LateralJoin lateral-subquery#x [c1#x], Inner
+   :  +- SubqueryAlias __auto_generated_subquery_name
+   :     +- Project [c1#x, c2#x, c1#x, c2#x]
+   :        +- Join Inner, ((c1#x = c1#x) AND (c1#x = outer(c1#x)))
+   :           :- SubqueryAlias spark_catalog.default.t2
+   :           :  +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
+   :           :     +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as 
int) AS c2#x]
+   :           :        +- LocalRelation [col1#x, col2#x]
+   :           +- SubqueryAlias spark_catalog.default.t4
+   :              +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
+   :                 +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as 
int) AS c2#x]
+   :                    +- LocalRelation [col1#x, col2#x]
+   +- SubqueryAlias spark_catalog.default.t1
+      +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+         +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 != t4.c1 AND 
t2.c1 != t1.c1)
+-- !query analysis
+Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
++- LateralJoin lateral-subquery#x [c1#x], Inner
+   :  +- SubqueryAlias __auto_generated_subquery_name
+   :     +- Project [c1#x, c2#x, c1#x, c2#x]
+   :        +- Join Inner, (NOT (c1#x = c1#x) AND NOT (c1#x = outer(c1#x)))
+   :           :- SubqueryAlias spark_catalog.default.t2
+   :           :  +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
+   :           :     +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as 
int) AS c2#x]
+   :           :        +- LocalRelation [col1#x, col2#x]
+   :           +- SubqueryAlias spark_catalog.default.t4
+   :              +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
+   :                 +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as 
int) AS c2#x]
+   :                    +- LocalRelation [col1#x, col2#x]
+   +- SubqueryAlias spark_catalog.default.t1
+      +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+         +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT * FROM t1 LEFT JOIN lateral (SELECT * FROM t4 LEFT JOIN t2 ON t2.c1 = 
t4.c1 AND t2.c1 = t1.c1)
+-- !query analysis
+Project [c1#x, c2#x, c1#x, c2#x, c1#x, c2#x]
++- LateralJoin lateral-subquery#x [c1#x], LeftOuter
+   :  +- SubqueryAlias __auto_generated_subquery_name
+   :     +- Project [c1#x, c2#x, c1#x, c2#x]
+   :        +- Join LeftOuter, ((c1#x = c1#x) AND (c1#x = outer(c1#x)))
+   :           :- SubqueryAlias spark_catalog.default.t4
+   :           :  +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
+   :           :     +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as 
int) AS c2#x]
+   :           :        +- LocalRelation [col1#x, col2#x]
+   :           +- SubqueryAlias spark_catalog.default.t2
+   :              +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
+   :                 +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as 
int) AS c2#x]
+   :                    +- LocalRelation [col1#x, col2#x]
+   +- SubqueryAlias spark_catalog.default.t1
+      +- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
+         +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
 -- !query
 SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1)
 -- !query analysis
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out
index a55b1e717be..76c9bec5fb8 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out
@@ -1258,3 +1258,84 @@ Project [id#xL]
    :           +- Range (1, 2, step=1, splits=None)
    +- SubqueryAlias t1
       +- Range (1, 3, step=1, splits=None)
+
+
+-- !query
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(t1c) FROM
+  (SELECT t1c
+   FROM   t1 JOIN t2 ON (t1a = t0a AND t2b = t1b))
+)
+-- !query analysis
+Project [t0a#x, t0b#x]
++- Filter (cast(t0a#x as bigint) < scalar-subquery#x [t0a#x])
+   :  +- Aggregate [sum(t1c#x) AS sum(t1c)#xL]
+   :     +- SubqueryAlias __auto_generated_subquery_name
+   :        +- Project [t1c#x]
+   :           +- Join Inner, ((t1a#x = outer(t0a#x)) AND (t2b#x = t1b#x))
+   :              :- SubqueryAlias t1
+   :              :  +- View (`t1`, [t1a#x,t1b#x,t1c#x])
+   :              :     +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x 
as int) AS t1b#x, cast(col3#x as int) AS t1c#x]
+   :              :        +- LocalRelation [col1#x, col2#x, col3#x]
+   :              +- SubqueryAlias t2
+   :                 +- View (`t2`, [t2a#x,t2b#x,t2c#x])
+   :                    +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x 
as int) AS t2b#x, cast(col3#x as int) AS t2c#x]
+   :                       +- LocalRelation [col1#x, col2#x, col3#x]
+   +- SubqueryAlias t0
+      +- View (`t0`, [t0a#x,t0b#x])
+         +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS 
t0b#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(t1c) FROM
+  (SELECT t1c
+   FROM   t1 JOIN t2 ON (t1a < t0a AND t2b >= t1b))
+)
+-- !query analysis
+Project [t0a#x, t0b#x]
++- Filter (cast(t0a#x as bigint) < scalar-subquery#x [t0a#x])
+   :  +- Aggregate [sum(t1c#x) AS sum(t1c)#xL]
+   :     +- SubqueryAlias __auto_generated_subquery_name
+   :        +- Project [t1c#x]
+   :           +- Join Inner, ((t1a#x < outer(t0a#x)) AND (t2b#x >= t1b#x))
+   :              :- SubqueryAlias t1
+   :              :  +- View (`t1`, [t1a#x,t1b#x,t1c#x])
+   :              :     +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x 
as int) AS t1b#x, cast(col3#x as int) AS t1c#x]
+   :              :        +- LocalRelation [col1#x, col2#x, col3#x]
+   :              +- SubqueryAlias t2
+   :                 +- View (`t2`, [t2a#x,t2b#x,t2c#x])
+   :                    +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x 
as int) AS t2b#x, cast(col3#x as int) AS t2c#x]
+   :                       +- LocalRelation [col1#x, col2#x, col3#x]
+   +- SubqueryAlias t0
+      +- View (`t0`, [t0a#x,t0b#x])
+         +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS 
t0b#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(t1c) FROM
+  (SELECT t1c
+  FROM  t1 LEFT JOIN t2 ON (t1a = t0a AND t2b = t0b))
+)
+-- !query analysis
+Project [t0a#x, t0b#x]
++- Filter (cast(t0a#x as bigint) < scalar-subquery#x [t0a#x && t0b#x])
+   :  +- Aggregate [sum(t1c#x) AS sum(t1c)#xL]
+   :     +- SubqueryAlias __auto_generated_subquery_name
+   :        +- Project [t1c#x]
+   :           +- Join LeftOuter, ((t1a#x = outer(t0a#x)) AND (t2b#x = 
outer(t0b#x)))
+   :              :- SubqueryAlias t1
+   :              :  +- View (`t1`, [t1a#x,t1b#x,t1c#x])
+   :              :     +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x 
as int) AS t1b#x, cast(col3#x as int) AS t1c#x]
+   :              :        +- LocalRelation [col1#x, col2#x, col3#x]
+   :              +- SubqueryAlias t2
+   :                 +- View (`t2`, [t2a#x,t2b#x,t2c#x])
+   :                    +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x 
as int) AS t2b#x, cast(col3#x as int) AS t2c#x]
+   :                       +- LocalRelation [col1#x, col2#x, col3#x]
+   +- SubqueryAlias t0
+      +- View (`t0`, [t0a#x,t0b#x])
+         +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS 
t0b#x]
+            +- LocalRelation [col1#x, col2#x]
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out
index 790d9da94e1..3f9eeb2cd59 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out
@@ -1802,3 +1802,83 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
     "fragment" : "SELECT sum(t0a) as d\n  FROM   t1"
   } ]
 }
+
+
+-- !query
+SELECT t0a, (SELECT sum(t1b) FROM
+  (SELECT t1b
+  FROM   t1 join t2 ON (t1a = t0a and t1b = t2b)
+  UNION ALL
+  SELECT t2b
+  FROM   t1 join t2 ON (t2a = t0a and t1a = t2a))
+)
+FROM t0
+-- !query analysis
+Project [t0a#x, scalar-subquery#x [t0a#x && t0a#x] AS scalarsubquery(t0a, 
t0a)#xL]
+:  +- Aggregate [sum(t1b#x) AS sum(t1b)#xL]
+:     +- SubqueryAlias __auto_generated_subquery_name
+:        +- Union false, false
+:           :- Project [t1b#x]
+:           :  +- Join Inner, ((t1a#x = outer(t0a#x)) AND (t1b#x = t2b#x))
+:           :     :- SubqueryAlias t1
+:           :     :  +- View (`t1`, [t1a#x,t1b#x,t1c#x])
+:           :     :     +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x 
as int) AS t1b#x, cast(col3#x as int) AS t1c#x]
+:           :     :        +- LocalRelation [col1#x, col2#x, col3#x]
+:           :     +- SubqueryAlias t2
+:           :        +- View (`t2`, [t2a#x,t2b#x,t2c#x])
+:           :           +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x 
as int) AS t2b#x, cast(col3#x as int) AS t2c#x]
+:           :              +- LocalRelation [col1#x, col2#x, col3#x]
+:           +- Project [t2b#x]
+:              +- Join Inner, ((t2a#x = outer(t0a#x)) AND (t1a#x = t2a#x))
+:                 :- SubqueryAlias t1
+:                 :  +- View (`t1`, [t1a#x,t1b#x,t1c#x])
+:                 :     +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x 
as int) AS t1b#x, cast(col3#x as int) AS t1c#x]
+:                 :        +- LocalRelation [col1#x, col2#x, col3#x]
+:                 +- SubqueryAlias t2
+:                    +- View (`t2`, [t2a#x,t2b#x,t2c#x])
+:                       +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x 
as int) AS t2b#x, cast(col3#x as int) AS t2c#x]
+:                          +- LocalRelation [col1#x, col2#x, col3#x]
++- SubqueryAlias t0
+   +- View (`t0`, [t0a#x,t0b#x])
+      +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS t0b#x]
+         +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT t0a, (SELECT sum(t1b) FROM
+  (SELECT t1b
+  FROM   t1 left join t2 ON (t1a = t0a and t1b = t2b)
+  UNION ALL
+  SELECT t2b
+  FROM   t1 join t2 ON (t2a = t0a + 1 and t1a = t2a))
+)
+FROM t0
+-- !query analysis
+Project [t0a#x, scalar-subquery#x [t0a#x && t0a#x] AS scalarsubquery(t0a, 
t0a)#xL]
+:  +- Aggregate [sum(t1b#x) AS sum(t1b)#xL]
+:     +- SubqueryAlias __auto_generated_subquery_name
+:        +- Union false, false
+:           :- Project [t1b#x]
+:           :  +- Join LeftOuter, ((t1a#x = outer(t0a#x)) AND (t1b#x = t2b#x))
+:           :     :- SubqueryAlias t1
+:           :     :  +- View (`t1`, [t1a#x,t1b#x,t1c#x])
+:           :     :     +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x 
as int) AS t1b#x, cast(col3#x as int) AS t1c#x]
+:           :     :        +- LocalRelation [col1#x, col2#x, col3#x]
+:           :     +- SubqueryAlias t2
+:           :        +- View (`t2`, [t2a#x,t2b#x,t2c#x])
+:           :           +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x 
as int) AS t2b#x, cast(col3#x as int) AS t2c#x]
+:           :              +- LocalRelation [col1#x, col2#x, col3#x]
+:           +- Project [t2b#x]
+:              +- Join Inner, ((t2a#x = (outer(t0a#x) + 1)) AND (t1a#x = 
t2a#x))
+:                 :- SubqueryAlias t1
+:                 :  +- View (`t1`, [t1a#x,t1b#x,t1c#x])
+:                 :     +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x 
as int) AS t1b#x, cast(col3#x as int) AS t1c#x]
+:                 :        +- LocalRelation [col1#x, col2#x, col3#x]
+:                 +- SubqueryAlias t2
+:                    +- View (`t2`, [t2a#x,t2b#x,t2c#x])
+:                       +- Project [cast(col1#x as int) AS t2a#x, cast(col2#x 
as int) AS t2b#x, cast(col3#x as int) AS t2c#x]
+:                          +- LocalRelation [col1#x, col2#x, col3#x]
++- SubqueryAlias t0
+   +- View (`t0`, [t0a#x,t0b#x])
+      +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS t0b#x]
+         +- LocalRelation [col1#x, col2#x]
diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql 
b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
index 29ff29d6630..2787a865975 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
@@ -101,6 +101,11 @@ SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a));
 -- lateral join inside correlated subquery
 SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a) WHERE c1 = t1.c1);
 
+-- join condition has a correlated reference to the left side of the lateral join
+SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1);
+SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 != t4.c1 AND t2.c1 != t1.c1);
+SELECT * FROM t1 LEFT JOIN lateral (SELECT * FROM t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1);
+
 -- COUNT bug with a single aggregate expression
 SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1);
 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql
index e015d577549..a49f30773ca 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql
@@ -405,3 +405,23 @@ from range(1, 3) t1
 where (select t2.id c
        from range (1, 2) t2 where t1.id = t2.id
       ) is not null;
+
+-- Correlated references in join predicates
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(t1c) FROM
+  (SELECT t1c
+   FROM   t1 JOIN t2 ON (t1a = t0a AND t2b = t1b))
+);
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(t1c) FROM
+  (SELECT t1c
+   FROM   t1 JOIN t2 ON (t1a < t0a AND t2b >= t1b))
+);
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(t1c) FROM
+  (SELECT t1c
+  FROM  t1 LEFT JOIN t2 ON (t1a = t0a AND t2b = t0b))
+);
+
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-set-op.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-set-op.sql
index 8f03f7e4100..39e456611c0 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-set-op.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-set-op.sql
@@ -619,3 +619,23 @@ SELECT t0a, (SELECT sum(d) FROM
   FROM   t2)
 )
 FROM t0;
+
+-- Correlated references in join predicates
+SELECT t0a, (SELECT sum(t1b) FROM
+  (SELECT t1b
+  FROM   t1 join t2 ON (t1a = t0a and t1b = t2b)
+  UNION ALL
+  SELECT t2b
+  FROM   t1 join t2 ON (t2a = t0a and t1a = t2a))
+)
+FROM t0;
+
+
+SELECT t0a, (SELECT sum(t1b) FROM
+  (SELECT t1b
+  FROM   t1 left join t2 ON (t1a = t0a and t1b = t2b)
+  UNION ALL
+  SELECT t2b
+  FROM   t1 join t2 ON (t2a = t0a + 1 and t1a = t2a))
+)
+FROM t0;
diff --git a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
index 0bb83be0f03..33f084f3d86 100644
--- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
@@ -572,6 +572,45 @@ struct<c1:int,c2:int>
 0      1
 
 
+-- !query
+SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)
+-- !query schema
+struct<c1:int,c2:int,c1:int,c2:int,c1:int,c2:int>
+-- !query output
+0      1       0       2       0       1
+0      1       0       2       0       2
+0      1       0       3       0       1
+0      1       0       3       0       2
+
+
+-- !query
+SELECT * FROM t1 JOIN lateral (SELECT * FROM t2 JOIN t4 ON t2.c1 != t4.c1 AND t2.c1 != t1.c1)
+-- !query schema
+struct<c1:int,c2:int,c1:int,c2:int,c1:int,c2:int>
+-- !query output
+1      2       0       2       1       1
+1      2       0       2       1       3
+1      2       0       3       1       1
+1      2       0       3       1       3
+
+
+-- !query
+SELECT * FROM t1 LEFT JOIN lateral (SELECT * FROM t4 LEFT JOIN t2 ON t2.c1 = t4.c1 AND t2.c1 = t1.c1)
+-- !query schema
+struct<c1:int,c2:int,c1:int,c2:int,c1:int,c2:int>
+-- !query output
+0      1       0       1       0       2
+0      1       0       1       0       3
+0      1       0       2       0       2
+0      1       0       2       0       3
+0      1       1       1       NULL    NULL
+0      1       1       3       NULL    NULL
+1      2       0       1       NULL    NULL
+1      2       0       2       NULL    NULL
+1      2       1       1       NULL    NULL
+1      2       1       3       NULL    NULL
+
+
 -- !query
 SELECT * FROM t1, LATERAL (SELECT COUNT(*) cnt FROM t2 WHERE c1 = t1.c1)
 -- !query schema
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out
index ef5d941dc97..302c5e6dd7e 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out
@@ -660,3 +660,40 @@ where (select t2.id c
 struct<id:bigint>
 -- !query output
 1
+
+
+-- !query
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(t1c) FROM
+  (SELECT t1c
+   FROM   t1 JOIN t2 ON (t1a = t0a AND t2b = t1b))
+)
+-- !query schema
+struct<t0a:int,t0b:int>
+-- !query output
+1      1
+
+
+-- !query
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(t1c) FROM
+  (SELECT t1c
+   FROM   t1 JOIN t2 ON (t1a < t0a AND t2b >= t1b))
+)
+-- !query schema
+struct<t0a:int,t0b:int>
+-- !query output
+2      0
+
+
+-- !query
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(t1c) FROM
+  (SELECT t1c
+  FROM  t1 LEFT JOIN t2 ON (t1a = t0a AND t2b = t0b))
+)
+-- !query schema
+struct<t0a:int,t0b:int>
+-- !query output
+1      1
+2      0
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out
index 2799728d48a..33a57a73be0 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-set-op.sql.out
@@ -1041,3 +1041,35 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
     "fragment" : "SELECT sum(t0a) as d\n  FROM   t1"
   } ]
 }
+
+
+-- !query
+SELECT t0a, (SELECT sum(t1b) FROM
+  (SELECT t1b
+  FROM   t1 join t2 ON (t1a = t0a and t1b = t2b)
+  UNION ALL
+  SELECT t2b
+  FROM   t1 join t2 ON (t2a = t0a and t1a = t2a))
+)
+FROM t0
+-- !query schema
+struct<t0a:int,scalarsubquery(t0a, t0a):bigint>
+-- !query output
+1      2
+2      NULL
+
+
+-- !query
+SELECT t0a, (SELECT sum(t1b) FROM
+  (SELECT t1b
+  FROM   t1 left join t2 ON (t1a = t0a and t1b = t2b)
+  UNION ALL
+  SELECT t2b
+  FROM   t1 join t2 ON (t2a = t0a + 1 and t1a = t2a))
+)
+FROM t0
+-- !query schema
+struct<t0a:int,scalarsubquery(t0a, t0a):bigint>
+-- !query output
+1      1
+2      1


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to