Repository: spark
Updated Branches:
  refs/heads/master 6e6320122 -> 95e372141


[SPARK-14781] [SQL] support nested predicate subquery

## What changes were proposed in this pull request?

In order to support nested predicate subqueries, this PR introduces an internal
join type, ExistenceJoin, which emits every row from the left side plus an
additional boolean column indicating whether any row from the right side matched
(it is not null-aware yet). This additional column can then replace the
subquery in the Filter.

In theory, every predicate subquery could use this join type, but it is slower
than LeftSemi and LeftAnti, so it is only used for nested subqueries (subqueries
inside an OR).

For example, the following SQL is now supported:
```sql
SELECT a FROM t  WHERE EXISTS (select 0) OR EXISTS (select 1)
```

This PR also fixes a bug where predicate subqueries were pushed down through
joins (they should not be).

Nested null-aware subqueries are still not supported, for example
`a > 3 OR b NOT IN (select bb from t)`.

With this change, we can run TPC-DS queries Q10, Q35, and Q45.

## How was this patch tested?

Added unit tests.

Author: Davies Liu <dav...@databricks.com>

Closes #12820 from davies/or_exists.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/95e37214
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/95e37214
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/95e37214

Branch: refs/heads/master
Commit: 95e372141a102f933045fe9472bbe1ce8c91b5d5
Parents: 6e63201
Author: Davies Liu <dav...@databricks.com>
Authored: Mon May 2 12:58:59 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Mon May 2 12:58:59 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  5 +-
 .../sql/catalyst/expressions/subquery.scala     | 15 +++-
 .../sql/catalyst/optimizer/Optimizer.scala      | 41 +++++++--
 .../spark/sql/catalyst/plans/joinTypes.scala    | 10 +++
 .../plans/logical/basicLogicalOperators.scala   |  4 +
 .../catalyst/analysis/AnalysisErrorSuite.scala  | 11 ++-
 .../spark/sql/execution/SparkStrategies.scala   |  1 +
 .../execution/joins/BroadcastHashJoinExec.scala | 66 +++++++++++++-
 .../joins/BroadcastNestedLoopJoinExec.scala     | 94 ++++++++++++++------
 .../spark/sql/execution/joins/HashJoin.scala    | 31 ++++++-
 .../execution/joins/ShuffledHashJoinExec.scala  | 13 +--
 .../sql/execution/joins/SortMergeJoinExec.scala | 40 +++++++++
 .../org/apache/spark/sql/SubquerySuite.scala    | 25 ++++++
 .../execution/joins/ExistenceJoinSuite.scala    | 50 ++++++++++-
 14 files changed, 345 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 61a7d9e..6e3a14d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -115,8 +115,9 @@ trait CheckAnalysis extends PredicateHelper {
           case f @ Filter(condition, child) =>
             splitConjunctivePredicates(condition).foreach {
               case _: PredicateSubquery | Not(_: PredicateSubquery) =>
-              case e if PredicateSubquery.hasPredicateSubquery(e) =>
-                failAnalysis(s"Predicate sub-queries cannot be used in nested 
conditions: $e")
+              case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) =>
+                failAnalysis(s"Null-aware predicate sub-queries cannot be used 
in nested" +
+                  s" conditions: $e")
               case e =>
             }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index cd6d3a0..eed062f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -92,7 +92,7 @@ case class PredicateSubquery(
   extends SubqueryExpression with Predicate with Unevaluable {
   override lazy val resolved = childrenResolved && query.resolved
   override lazy val references: AttributeSet = super.references -- 
query.outputSet
-  override def nullable: Boolean = false
+  override def nullable: Boolean = nullAware
   override def plan: LogicalPlan = SubqueryAlias(toString, query)
   override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query 
= plan)
   override def toString: String = s"predicate-subquery#${exprId.id} 
$conditionString"
@@ -105,6 +105,19 @@ object PredicateSubquery {
       case _ => false
     }.isDefined
   }
+
+  /**
+   * Returns whether `e` contains a null-aware predicate subquery underneath some Not. If not,
+   * we could turn the null-aware predicate into a not-null-aware predicate.
+   */
+  def hasNullAwarePredicateWithinNot(e: Expression): Boolean = {
+    e.find { x =>
+      // Search only beneath this Not: scanning all of `e` would wrongly reject e.g.
+      // `a IN (subquery) OR NOT (b > 1)`, where the IN itself is not negated.
+      x.isInstanceOf[Not] && x.find {
+        case p: PredicateSubquery => p.nullAware
+        case _ => false
+      }.isDefined
+    }.isDefined
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a147fff..e1c969f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -100,8 +100,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, 
conf: CatalystConf)
       EliminateSorts,
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
-      EliminateSerialization,
-      RewritePredicateSubquery) ::
+      EliminateSerialization) ::
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) ::
     Batch("Typed Filter Optimization", fixedPoint,
@@ -109,7 +108,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, 
conf: CatalystConf)
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation) ::
     Batch("OptimizeCodegen", Once,
-      OptimizeCodegen(conf)) :: Nil
+      OptimizeCodegen(conf)) ::
+    Batch("RewriteSubquery", Once,
+      RewritePredicateSubquery,
+      CollapseProject) :: Nil
   }
 
   /**
@@ -1078,7 +1080,14 @@ object ReorderJoin extends Rule[LogicalPlan] with 
PredicateHelper {
   def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): 
LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
-      Join(input(0), input(1), Inner, conditions.reduceLeftOption(And))
+      val (joinConditions, others) = conditions.partition(
+        e => !PredicateSubquery.hasPredicateSubquery(e))
+      val join = Join(input(0), input(1), Inner, 
joinConditions.reduceLeftOption(And))
+      if (others.nonEmpty) {
+        Filter(others.reduceLeft(And), join)
+      } else {
+        join
+      }
     } else {
       val left :: rest = input.toList
       // find out the first join that have at least one join condition
@@ -1091,7 +1100,8 @@ object ReorderJoin extends Rule[LogicalPlan] with 
PredicateHelper {
       val right = conditionalJoin.getOrElse(rest.head)
 
       val joinedRefs = left.outputSet ++ right.outputSet
-      val (joinConditions, others) = 
conditions.partition(_.references.subsetOf(joinedRefs))
+      val (joinConditions, others) = conditions.partition(
+        e => e.references.subsetOf(joinedRefs) && 
!PredicateSubquery.hasPredicateSubquery(e))
       val joined = Join(left, right, Inner, 
joinConditions.reduceLeftOption(And))
 
       // should not have reference to same logical plan
@@ -1201,9 +1211,16 @@ object PushPredicateThroughJoin extends 
Rule[LogicalPlan] with PredicateHelper {
             reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
           val newRight = rightFilterConditions.
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
-          val newJoinCond = (commonFilterCondition ++ 
joinCondition).reduceLeftOption(And)
+          val (newJoinConditions, others) =
+            commonFilterCondition.partition(e => 
!PredicateSubquery.hasPredicateSubquery(e))
+          val newJoinCond = (newJoinConditions ++ 
joinCondition).reduceLeftOption(And)
 
-          Join(newLeft, newRight, Inner, newJoinCond)
+          val join = Join(newLeft, newRight, Inner, newJoinCond)
+          if (others.nonEmpty) {
+            Filter(others.reduceLeft(And), join)
+          } else {
+            join
+          }
         case RightOuter =>
           // push down the right side only `where` condition
           val newLeft = left
@@ -1543,6 +1560,16 @@ object RewritePredicateSubquery extends 
Rule[LogicalPlan] with PredicateHelper {
           // Note that will almost certainly be planned as a Broadcast Nested 
Loop join. Use EXISTS
           // if performance matters to you.
           Join(p, sub, LeftAnti, Option(Or(anyNull, condition)))
+        case (p, predicate) =>
+          var joined = p
+          val replaced = predicate transformUp {
+            case PredicateSubquery(sub, conditions, nullAware, _) =>
+              // TODO: support null-aware join
+              val exists = AttributeReference("exists", BooleanType, false)()
+              joined = Join(joined, sub, ExistenceJoin(exists), 
conditions.reduceLeftOption(And))
+              exists
+          }
+          Project(p.output, Filter(replaced, joined))
       }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 13f57c5..80674d9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.Attribute
 
 object JoinType {
   def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
@@ -69,6 +70,14 @@ case object LeftAnti extends JoinType {
   override def sql: String = "LEFT ANTI"
 }
 
+case class ExistenceJoin(exists: Attribute) extends JoinType {
+  override def sql: String = {
+    // This join type is only used in the end of optimizer and physical plans, 
we will not
+    // generate SQL for this join type
+    throw new UnsupportedOperationException
+  }
+}
+
 case class NaturalJoin(tpe: JoinType) extends JoinType {
   require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe),
     "Unsupported natural join type " + tpe)
@@ -84,6 +93,7 @@ case class UsingJoin(tpe: JoinType, usingColumns: 
Seq[UnresolvedAttribute]) exte
 object LeftExistence {
   def unapply(joinType: JoinType): Option[JoinType] = joinType match {
     case LeftSemi | LeftAnti => Some(joinType)
+    case j: ExistenceJoin => Some(joinType)
     case _ => None
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index b2297bb..830a7ac 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -273,6 +273,8 @@ case class Join(
 
   override def output: Seq[Attribute] = {
     joinType match {
+      case j: ExistenceJoin =>
+        left.output :+ j.exists
       case LeftExistence(_) =>
         left.output
       case LeftOuter =>
@@ -295,6 +297,8 @@ case class Join(
       case LeftSemi if condition.isDefined =>
         left.constraints
           .union(splitConjunctivePredicates(condition.get).toSet)
+      case j: ExistenceJoin =>
+        left.constraints
       case Inner =>
         left.constraints.union(right.constraints)
       case LeftExistence(_) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 1b08913..10bff3d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -459,11 +459,14 @@ class AnalysisErrorSuite extends AnalysisTest {
     val a = AttributeReference("a", IntegerType)()
     val b = AttributeReference("b", IntegerType)()
     val c = AttributeReference("c", BooleanType)()
-    val plan1 = Filter(Cast(In(a, Seq(ListQuery(LocalRelation(b)))), 
BooleanType), LocalRelation(a))
-    assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested 
conditions" :: Nil)
+    val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), 
BooleanType),
+      LocalRelation(a))
+    assertAnalysisError(plan1,
+      "Null-aware predicate sub-queries cannot be used in nested conditions" 
:: Nil)
 
-    val plan2 = Filter(Or(In(a, Seq(ListQuery(LocalRelation(b)))), c), 
LocalRelation(a, c))
-    assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested 
conditions" :: Nil)
+    val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), 
LocalRelation(a, c))
+    assertAnalysisError(plan2,
+      "Null-aware predicate sub-queries cannot be used in nested conditions" 
:: Nil)
   }
 
   test("PredicateSubQuery correlated predicate is nested in an illegal plan") {

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 238334e..9747e58 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -118,6 +118,7 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
 
     private def canBuildRight(joinType: JoinType): Boolean = joinType match {
       case Inner | LeftOuter | LeftSemi | LeftAnti => true
+      case j: ExistenceJoin => true
       case _ => false
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 587c603..7c194ab 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -48,8 +48,6 @@ case class BroadcastHashJoinExec(
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
 
-  override def outputPartitioning: Partitioning = 
streamedPlan.outputPartitioning
-
   override def requiredChildDistribution: Seq[Distribution] = {
     val mode = HashedRelationBroadcastMode(buildKeys)
     buildSide match {
@@ -85,6 +83,7 @@ case class BroadcastHashJoinExec(
       case LeftOuter | RightOuter => codegenOuter(ctx, input)
       case LeftSemi => codegenSemi(ctx, input)
       case LeftAnti => codegenAnti(ctx, input)
+      case j: ExistenceJoin => codegenExistence(ctx, input)
       case x =>
         throw new IllegalArgumentException(
           s"BroadcastHashJoin should not take $x as the JoinType")
@@ -407,4 +406,67 @@ case class BroadcastHashJoinExec(
        """.stripMargin
     }
   }
+
+  /**
+   * Generates the code for existence join.
+   */
+  private def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]): 
String = {
+    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    val existsVar = ctx.freshName("exists")
+
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from build side that used by condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, 
expr.references)
+      // filter the output via condition
+      ctx.currentVars = input ++ buildVars
+      val ev =
+        BindReferences.bindReference(expr, streamedPlan.output ++ 
buildPlan.output).genCode(ctx)
+      s"""
+         |$eval
+         |${ev.code}
+         |$existsVar = !${ev.isNull} && ${ev.value};
+       """.stripMargin
+    } else {
+      s"$existsVar = true;"
+    }
+
+    val resultVar = input ++ Seq(ExprCode("", "false", existsVar))
+    if (broadcastRelation.value.keyIsUnique) {
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: 
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |boolean $existsVar = false;
+         |if ($matched != null) {
+         |  $checkCondition
+         |}
+         |$numOutput.add(1);
+         |${consume(ctx, resultVar)}
+       """.stripMargin
+    } else {
+      val matches = ctx.freshName("matches")
+      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashRelation
+         |$iteratorCls $matches = $anyNull ? null : 
($iteratorCls)$relationTerm.get(${keyEv.value});
+         |boolean $existsVar = false;
+         |if ($matches != null) {
+         |  while (!$existsVar && $matches.hasNext()) {
+         |    UnsafeRow $matched = (UnsafeRow) $matches.next();
+         |    $checkCondition
+         |  }
+         |}
+         |$numOutput.add(1);
+         |${consume(ctx, resultVar)}
+       """.stripMargin
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index a659bf2..2a250ec 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -50,19 +50,16 @@ case class BroadcastNestedLoopJoinExec(
       UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) 
:: Nil
   }
 
-  private[this] def genResultProjection: InternalRow => InternalRow = {
-    if (joinType == LeftSemi) {
+  private[this] def genResultProjection: InternalRow => InternalRow = joinType 
match {
+    case LeftExistence(j) =>
       UnsafeProjection.create(output, output)
-    } else {
+    case other =>
       // Always put the stream side on left to simplify implementation
       // both of left and right side could be null
       UnsafeProjection.create(
         output, (streamed.output ++ 
broadcast.output).map(_.withNullability(true)))
-    }
   }
 
-  override def outputPartitioning: Partitioning = streamed.outputPartitioning
-
   override def output: Seq[Attribute] = {
     joinType match {
       case Inner =>
@@ -73,6 +70,8 @@ case class BroadcastNestedLoopJoinExec(
         left.output.map(_.withNullability(true)) ++ right.output
       case FullOuter =>
         left.output.map(_.withNullability(true)) ++ 
right.output.map(_.withNullability(true))
+      case j: ExistenceJoin =>
+        left.output :+ j.exists
       case LeftExistence(_) =>
         left.output
       case x =>
@@ -197,6 +196,28 @@ case class BroadcastNestedLoopJoinExec(
     }
   }
 
+  private def existenceJoin(relation: Broadcast[Array[InternalRow]]): 
RDD[InternalRow] = {
+    assert(buildSide == BuildRight)
+    streamed.execute().mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
+      val joinedRow = new JoinedRow
+
+      if (condition.isDefined) {
+        val resultRow = new GenericMutableRow(Array[Any](null))
+        streamedIter.map { row =>
+          val result = buildRows.exists(r => boundCondition(joinedRow(row, r)))
+          resultRow.setBoolean(0, result)
+          joinedRow(row, resultRow)
+        }
+      } else {
+        val resultRow = new GenericMutableRow(Array[Any](buildRows.nonEmpty))
+        streamedIter.map { row =>
+          joinedRow(row, resultRow)
+        }
+      }
+    }
+  }
+
   /**
    * The implementation for these joins:
    *
@@ -204,7 +225,8 @@ case class BroadcastNestedLoopJoinExec(
    *   RightOuter with BuildRight
    *   FullOuter
    *   LeftSemi with BuildLeft
-   *   Anti with BuildLeft
+   *   LeftAnti with BuildLeft
+   *   ExistenceJoin with BuildLeft
    */
   private def defaultJoin(relation: Broadcast[Array[InternalRow]]): 
RDD[InternalRow] = {
     /** All rows that either match both-way, or rows from streamed joined with 
nulls. */
@@ -231,27 +253,50 @@ case class BroadcastNestedLoopJoinExec(
       new BitSet(relation.value.length)
     )(_ | _)
 
-    if (joinType == LeftSemi) {
-      assert(buildSide == BuildLeft)
-      val buf: CompactBuffer[InternalRow] = new CompactBuffer()
-      var i = 0
-      val rel = relation.value
-      while (i < rel.length) {
-        if (matchedBroadcastRows.get(i)) {
-          buf += rel(i).copy()
+    joinType match {
+      case LeftSemi =>
+        assert(buildSide == BuildLeft)
+        val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+        var i = 0
+        val rel = relation.value
+        while (i < rel.length) {
+          if (matchedBroadcastRows.get(i)) {
+            buf += rel(i).copy()
+          }
+          i += 1
         }
-        i += 1
-      }
-      return sparkContext.makeRDD(buf)
+        return sparkContext.makeRDD(buf)
+      case j: ExistenceJoin =>
+        val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+        var i = 0
+        val rel = relation.value
+        while (i < rel.length) {
+          val result = new 
GenericInternalRow(Array[Any](matchedBroadcastRows.get(i)))
+          buf += new JoinedRow(rel(i).copy(), result)
+          i += 1
+        }
+        return sparkContext.makeRDD(buf)
+      case LeftAnti =>
+        val notMatched: CompactBuffer[InternalRow] = new CompactBuffer()
+        var i = 0
+        val rel = relation.value
+        while (i < rel.length) {
+          if (!matchedBroadcastRows.get(i)) {
+            notMatched += rel(i).copy()
+          }
+          i += 1
+        }
+        return sparkContext.makeRDD(notMatched)
+      case o =>
     }
 
     val notMatchedBroadcastRows: Seq[InternalRow] = {
       val nulls = new GenericMutableRow(streamed.output.size)
       val buf: CompactBuffer[InternalRow] = new CompactBuffer()
-      var i = 0
-      val buildRows = relation.value
       val joinedRow = new JoinedRow
       joinedRow.withLeft(nulls)
+      var i = 0
+      val buildRows = relation.value
       while (i < buildRows.length) {
         if (!matchedBroadcastRows.get(i)) {
           buf += joinedRow.withRight(buildRows(i)).copy()
@@ -261,10 +306,6 @@ case class BroadcastNestedLoopJoinExec(
       buf
     }
 
-    if (joinType == LeftAnti) {
-      return sparkContext.makeRDD(notMatchedBroadcastRows)
-    }
-
     val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
       val buildRows = relation.value
       val joinedRow = new JoinedRow
@@ -308,13 +349,16 @@ case class BroadcastNestedLoopJoinExec(
         leftExistenceJoin(broadcastedRelation, exists = true)
       case (LeftAnti, BuildRight) =>
         leftExistenceJoin(broadcastedRelation, exists = false)
+      case (j: ExistenceJoin, BuildRight) =>
+        existenceJoin(broadcastedRelation)
       case _ =>
         /**
          * LeftOuter with BuildLeft
          * RightOuter with BuildRight
          * FullOuter
          * LeftSemi with BuildLeft
-         * Anti with BuildLeft
+         * LeftAnti with BuildLeft
+         * ExistenceJoin with BuildLeft
          */
         defaultJoin(broadcastedRelation)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 9c173d7..d46a804 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.{IntegralType, LongType}
@@ -43,6 +44,8 @@ trait HashJoin {
         left.output ++ right.output.map(_.withNullability(true))
       case RightOuter =>
         left.output.map(_.withNullability(true)) ++ right.output
+      case j: ExistenceJoin =>
+        left.output :+ j.exists
       case LeftExistence(_) =>
         left.output
       case x =>
@@ -50,6 +53,8 @@ trait HashJoin {
     }
   }
 
+  override def outputPartitioning: Partitioning = 
streamedPlan.outputPartitioning
+
   protected lazy val (buildPlan, streamedPlan) = buildSide match {
     case BuildLeft => (left, right)
     case BuildRight => (right, left)
@@ -110,15 +115,14 @@ trait HashJoin {
     (r: InternalRow) => true
   }
 
-  protected def createResultProjection(): (InternalRow) => InternalRow = {
-    if (joinType == LeftSemi) {
+  protected def createResultProjection(): (InternalRow) => InternalRow = 
joinType match {
+    case LeftExistence(_) =>
       UnsafeProjection.create(output, output)
-    } else {
+    case _ =>
       // Always put the stream side on left to simplify implementation
       // both of left and right side could be null
       UnsafeProjection.create(
         output, (streamedPlan.output ++ 
buildPlan.output).map(_.withNullability(true)))
-    }
   }
 
   private def innerJoin(
@@ -184,6 +188,23 @@ trait HashJoin {
     }
   }
 
+  /**
+   * Existence join: emits every streamed row joined with a single boolean column that is true
+   * when the hash relation holds at least one row for the row's key that also satisfies
+   * `condition` (if any). Rows with a null in the join key never match. The returned rows reuse
+   * the same `JoinedRow` / result holder across iterations, so callers must copy if they retain them.
+   */
+  private def existenceJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val result = new GenericMutableRow(Array[Any](null))
+    val joinedRow = new JoinedRow
+    streamIter.map { current =>
+      val key = joinKeys(current)
+      // A plain block on the right of `&&` gives the same short-circuiting as a local
+      // `lazy val`, without paying lazy-initialization overhead for every streamed row.
+      val exists = !key.anyNull && {
+        val buildIter = hashedRelation.get(key)
+        buildIter != null && (condition.isEmpty || buildIter.exists {
+          (row: InternalRow) => boundCondition(joinedRow(current, row))
+        })
+      }
+      result.setBoolean(0, exists)
+      joinedRow(current, result)
+    }
+  }
+
   private def antiJoin(
       streamIter: Iterator[InternalRow],
       hashedRelation: HashedRelation): Iterator[InternalRow] = {
@@ -212,6 +233,8 @@ trait HashJoin {
         semiJoin(streamedIter, hashed)
       case LeftAnti =>
         antiJoin(streamedIter, hashed)
+      case j: ExistenceJoin =>
+        existenceJoin(streamedIter, hashed)
       case x =>
         throw new IllegalArgumentException(
           s"BroadcastHashJoin should not take $x as the JoinType")

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 3ef2fec..0036f9a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
@@ -44,17 +44,6 @@ case class ShuffledHashJoinExec(
     "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of 
build side"),
     "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build 
hash map"))
 
-  override def outputPartitioning: Partitioning = joinType match {
-    case Inner => PartitioningCollection(Seq(left.outputPartitioning, 
right.outputPartitioning))
-    case LeftAnti => left.outputPartitioning
-    case LeftSemi => left.outputPartitioning
-    case LeftOuter => left.outputPartitioning
-    case RightOuter => right.outputPartitioning
-    case FullOuter => 
UnknownPartitioning(left.outputPartitioning.numPartitions)
-    case x =>
-      throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x 
as the JoinType")
-  }
-
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 775f8ac..f0efa52 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -53,6 +53,8 @@ case class SortMergeJoinExec(
         left.output.map(_.withNullability(true)) ++ right.output
       case FullOuter =>
         (left.output ++ right.output).map(_.withNullability(true))
+      case j: ExistenceJoin =>
+        left.output :+ j.exists
       case LeftExistence(_) =>
         left.output
       case x =>
@@ -269,6 +271,44 @@ case class SortMergeJoinExec(
             override def getRow: InternalRow = currentLeftRow
           }.toScala
 
+        case j: ExistenceJoin =>
+          // One output row per streamed (left) row that findNextOuterJoinRows
+          // yields: the left columns followed by a single boolean that is true
+          // iff some buffered (right) match also satisfies boundCondition.
+          new RowIterator {
+            private[this] val scanner = new SortMergeJoinScanner(
+              createLeftKeyGenerator(),
+              createRightKeyGenerator(),
+              keyOrdering,
+              RowIterator.fromScala(leftIter),
+              RowIterator.fromScala(rightIter)
+            )
+            // Single-slot row holding the existence flag appended to each left row.
+            private[this] val existsFlag: MutableRow = new GenericMutableRow(Array[Any](null))
+            private[this] val joinRow = new JoinedRow
+            private[this] var leftRow: InternalRow = _
+
+            override def advanceNext(): Boolean = {
+              val hasLeftRow = scanner.findNextOuterJoinRows()
+              if (hasLeftRow) {
+                leftRow = scanner.getStreamedRow
+                val rightMatches = scanner.getBufferedMatches
+                // Stop probing as soon as one right row passes the condition.
+                var matched = false
+                var idx = 0
+                while (!matched && rightMatches != null && idx < rightMatches.length) {
+                  matched = boundCondition(joinRow(leftRow, rightMatches(idx)))
+                  idx += 1
+                }
+                existsFlag.setBoolean(0, matched)
+                numOutputRows += 1
+              }
+              hasLeftRow
+            }
+
+            override def getRow: InternalRow = resultProj(joinRow(leftRow, existsFlag))
+          }.toScala
+
         case x =>
           throw new IllegalArgumentException(
             s"SortMergeJoin should not take $x as the JoinType")

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 0bf4c6f..ff3f9bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -152,6 +152,19 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
       Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
   }
 
+  test("EXISTS predicate subquery within OR") { // EXISTS / NOT EXISTS as OR disjuncts
+    checkAnswer( // both OR'ed EXISTS subqueries are identical, correlated on l.a = r.c
+      sql("select * from l where exists (select * from r where l.a = r.c)" +
+        " or exists (select * from r where l.a = r.c)"),
+      Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil)
+
+    checkAnswer( // NOT EXISTS with an extra l.b < r.d filter, OR a plain NOT EXISTS
+      sql("select * from l where not exists (select * from r where l.a = r.c 
and l.b < r.d)" +
+        " or not exists (select * from r where l.a = r.c)"),
+      Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) ::
+        Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil)
+  }
+
   test("IN predicate subquery") {
     checkAnswer(
       sql("select * from l where l.a in (select c from r)"),
@@ -187,6 +200,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
 
   }
 
+  test("IN predicate subquery within OR") {
+    // A row qualifies when either IN subquery finds its `a` value.
+    checkAnswer(
+      sql("select * from l where l.a in (select c from r)" +
+        " or l.a in (select c from r where l.b < r.d)"),
+      Seq(Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(6, null)))
+    // NOT IN nested under OR needs null-aware semantics, which analysis rejects.
+    val nestedNotIn = "select * from l where a not in (select c from r)" +
+      " or a not in (select c from r where c is not null)"
+    intercept[AnalysisException](sql(nestedNotIn))
+  }
+
   test("complex IN predicate subquery") {
     checkAnswer(
       sql("select * from l where (a, b) not in (select c, d from r)"),

http://git-wip-us.apache.org/repos/asf/spark/blob/95e37214/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index b32b644..8093054 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -18,15 +18,15 @@
 package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, 
LeftSemi}
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan, 
SparkPlanTest}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
+import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, 
StructType}
 
 class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
 
@@ -89,6 +89,18 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
       ExtractEquiJoinKeys.unapply(join)
     }
 
+    val existsAttr = AttributeReference("exists", BooleanType, false)()
+    val leftSemiPlus = ExistenceJoin(existsAttr)
+    def createLeftSemiPlusJoin(join: SparkPlan): SparkPlan = {
+      // The existence join appends a boolean `exists` column after the left
+      // columns; dropping it restores the plain LeftSemi/LeftAnti output.
+      val leftColumns = join.output.dropRight(1)
+      // LeftSemi keeps rows whose flag is set; otherwise (e.g. LeftAnti) keep
+      // the rows where it is not.
+      val keepRow = if (joinType == LeftSemi) existsAttr else Not(existsAttr)
+      ProjectExec(leftColumns, FilterExec(keepRow, join))
+    }
+
     test(s"$testName using ShuffledHashJoin") {
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
@@ -98,6 +110,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
                 leftKeys, rightKeys, joinType, BuildRight, boundCondition, 
left, right)),
             expectedAnswer,
             sortAnswers = true)
+          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: 
SparkPlan) =>
+            EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+              createLeftSemiPlusJoin(ShuffledHashJoinExec(
+                leftKeys, rightKeys, leftSemiPlus, BuildRight, boundCondition, 
left, right))),
+            expectedAnswer,
+            sortAnswers = true)
         }
       }
     }
@@ -111,6 +129,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
                 leftKeys, rightKeys, joinType, BuildRight, boundCondition, 
left, right)),
             expectedAnswer,
             sortAnswers = true)
+          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: 
SparkPlan) =>
+            EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+              createLeftSemiPlusJoin(BroadcastHashJoinExec(
+                leftKeys, rightKeys, leftSemiPlus, BuildRight, boundCondition, 
left, right))),
+            expectedAnswer,
+            sortAnswers = true)
         }
       }
     }
@@ -123,6 +147,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
               SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, 
left, right)),
             expectedAnswer,
             sortAnswers = true)
+          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: 
SparkPlan) =>
+            EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+              createLeftSemiPlusJoin(SortMergeJoinExec(
+                leftKeys, rightKeys, leftSemiPlus, boundCondition, left, 
right))),
+            expectedAnswer,
+            sortAnswers = true)
         }
       }
     }
@@ -134,6 +164,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
             BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, 
Some(condition))),
           expectedAnswer,
           sortAnswers = true)
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) 
=>
+          EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+            createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec(
+              left, right, BuildLeft, leftSemiPlus, Some(condition)))),
+          expectedAnswer,
+          sortAnswers = true)
       }
     }
 
@@ -144,6 +180,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
             BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, 
Some(condition))),
           expectedAnswer,
           sortAnswers = true)
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) 
=>
+          EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+            createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec(
+              left, right, BuildRight, leftSemiPlus, Some(condition)))),
+          expectedAnswer,
+          sortAnswers = true)
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to