This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 78cc91c  [SPARK-32567][SQL] Add code-gen for full outer shuffled hash join
78cc91c is described below

commit 78cc91c962abd48d7ec2e9721d1e1429f802dced
Author: Cheng Su <chen...@fb.com>
AuthorDate: Wed Nov 3 11:18:12 2021 +0800

    [SPARK-32567][SQL] Add code-gen for full outer shuffled hash join
    
    ### What changes were proposed in this pull request?
    
    As the title says, this PR adds code-gen support for FULL OUTER shuffled hash join.
    
    The main change is in `ShuffledHashJoinExec.scala:doProduce()`, which now generates code for FULL OUTER join.
    * `ShuffledHashJoinExec.scala:codegenFullOuterJoinWithUniqueKey()` generates the code for joins where the build side has unique join keys.
    * `ShuffledHashJoinExec.scala:codegenFullOuterJoinWithNonUniqueKey()` generates the code for joins where the build side has non-unique join keys.
    
    Example query:
    
    ```
    val df1 = spark.range(5).select($"id".as("k1"))
    val df2 = spark.range(10).select($"id".as("k2"))
    df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2" % 3 && $"k1" + 3 =!= 
$"k2", "full_outer")
    ```
    
    Generated code for example query: 
https://gist.github.com/c21/828b782ee81827f4148939cb50314a7b
    
    ### Why are the changes needed?
    
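    Previously, `ShuffledHashJoinExec.supportCodegen` returned false for FULL OUTER joins, so they ran outside whole-stage code-gen. Adding code-gen support improves query performance for FULL OUTER shuffled hash join.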
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    * Added a unit test in `WholeStageCodegenSuite`.
    * Existing unit tests in `OuterJoinSuite`, plus a new full outer join test with unique keys.
    
    Closes #34444 from c21/shj-codegen.
    
    Authored-by: Cheng Su <chen...@fb.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../joins/BroadcastNestedLoopJoinExec.scala        |   2 +-
 .../spark/sql/execution/joins/HashJoin.scala       |   4 +-
 .../sql/execution/joins/JoinCodegenSupport.scala   |  50 ++--
 .../sql/execution/joins/ShuffledHashJoinExec.scala | 260 ++++++++++++++++++++-
 .../sql/execution/joins/SortMergeJoinExec.scala    |   3 +-
 .../sql/execution/WholeStageCodegenSuite.scala     |  45 +++-
 .../spark/sql/execution/joins/OuterJoinSuite.scala |  21 ++
 7 files changed, 349 insertions(+), 36 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 77a30b7..0677211 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -463,7 +463,7 @@ case class BroadcastNestedLoopJoinExec(
   private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String 
= {
     val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx)
     val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, 
broadcast)
-    val buildVars = genBuildSideVars(ctx, buildRow, broadcast)
+    val buildVars = genOneSideJoinVars(ctx, buildRow, broadcast, 
setDefaultValue = true)
 
     val resultVars = buildSide match {
       case BuildLeft => buildVars ++ input
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index f87acb8..0e8bb84 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -444,7 +444,7 @@ trait HashJoin extends JoinCodegenSupport {
     val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
-    val buildVars = genBuildSideVars(ctx, matched, buildPlan)
+    val buildVars = genOneSideJoinVars(ctx, matched, buildPlan, 
setDefaultValue = true)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     // filter the output via condition
@@ -646,7 +646,7 @@ trait HashJoin extends JoinCodegenSupport {
     val existsVar = ctx.freshName("exists")
 
     val matched = ctx.freshName("matched")
-    val buildVars = genBuildSideVars(ctx, matched, buildPlan)
+    val buildVars = genOneSideJoinVars(ctx, matched, buildPlan, 
setDefaultValue = false)
     val checkCondition = if (condition.isDefined) {
       val expr = condition.get
       // evaluate the variables from build side that used by condition
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala
index 96aa0be..75f0a35 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.sql.catalyst.expressions.{BindReferences, 
BoundReference}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, 
LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 
 /**
@@ -30,7 +29,7 @@ trait JoinCodegenSupport extends CodegenSupport with 
BaseJoinExec {
 
   /**
    * Generate the (non-equi) condition used to filter joined rows.
-   * This is used in Inner, Left Semi and Left Anti joins.
+   * This is used in Inner, Left Semi, Left Anti and Full Outer joins.
    *
    * @return Tuple of variable name for row of build side, generated code for 
condition,
    *         and generated code for variables of build side.
@@ -39,13 +38,15 @@ trait JoinCodegenSupport extends CodegenSupport with 
BaseJoinExec {
       ctx: CodegenContext,
       streamVars: Seq[ExprCode],
       streamPlan: SparkPlan,
-      buildPlan: SparkPlan): (String, String, Seq[ExprCode]) = {
-    val buildRow = ctx.freshName("buildRow")
-    val buildVars = genBuildSideVars(ctx, buildRow, buildPlan)
+      buildPlan: SparkPlan,
+      buildRow: Option[String] = None): (String, String, Seq[ExprCode]) = {
+    val buildSideRow = buildRow.getOrElse(ctx.freshName("buildRow"))
+    val buildVars = genOneSideJoinVars(ctx, buildSideRow, buildPlan, 
setDefaultValue = false)
     val checkCondition = if (condition.isDefined) {
       val expr = condition.get
       // evaluate the variables from build side that used by condition
       val eval = evaluateRequiredVariables(buildPlan.output, buildVars, 
expr.references)
+
       // filter the output via condition
       ctx.currentVars = streamVars ++ buildVars
       val ev =
@@ -59,41 +60,38 @@ trait JoinCodegenSupport extends CodegenSupport with 
BaseJoinExec {
     } else {
       ""
     }
-    (buildRow, checkCondition, buildVars)
+    (buildSideRow, checkCondition, buildVars)
   }
 
   /**
-   * Generates the code for variables of build side.
+   * Generates the code for variables of one child side of join.
    */
-  protected def genBuildSideVars(
+  protected def genOneSideJoinVars(
       ctx: CodegenContext,
-      buildRow: String,
-      buildPlan: SparkPlan): Seq[ExprCode] = {
+      row: String,
+      plan: SparkPlan,
+      setDefaultValue: Boolean): Seq[ExprCode] = {
     ctx.currentVars = null
-    ctx.INPUT_ROW = buildRow
-    buildPlan.output.zipWithIndex.map { case (a, i) =>
+    ctx.INPUT_ROW = row
+    plan.output.zipWithIndex.map { case (a, i) =>
       val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
-      joinType match {
-        case _: InnerLike | LeftSemi | LeftAnti | _: ExistenceJoin =>
-          ev
-        case LeftOuter | RightOuter =>
-          // the variables are needed even there is no matched rows
-          val isNull = ctx.freshName("isNull")
-          val value = ctx.freshName("value")
-          val javaType = CodeGenerator.javaType(a.dataType)
-          val code = code"""
+      if (setDefaultValue) {
+        // the variables are needed even there is no matched rows
+        val isNull = ctx.freshName("isNull")
+        val value = ctx.freshName("value")
+        val javaType = CodeGenerator.javaType(a.dataType)
+        val code = code"""
             |boolean $isNull = true;
             |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
-            |if ($buildRow != null) {
+            |if ($row != null) {
             |  ${ev.code}
             |  $isNull = ${ev.isNull};
             |  $value = ${ev.value};
             |}
           """.stripMargin
-          ExprCode(code, JavaCode.isNullVariable(isNull), 
JavaCode.variable(value, a.dataType))
-        case _ =>
-          throw new IllegalArgumentException(
-            s"JoinCodegenSupport.genBuildSideVars should not take $joinType as 
the JoinType")
+        ExprCode(code, JavaCode.isNullVariable(isNull), 
JavaCode.variable(value, a.dataType))
+      } else {
+        ev
       }
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 47b2bd2..7136229 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -311,11 +311,6 @@ case class ShuffledHashJoinExec(
     streamResultIter ++ buildResultIter
   }
 
-  // TODO(SPARK-32567): support full outer shuffled hash join code-gen
-  override def supportCodegen: Boolean = {
-    joinType != FullOuter
-  }
-
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     streamedPlan.execute() :: buildPlan.execute() :: Nil
   }
@@ -332,6 +327,261 @@ case class ShuffledHashJoinExec(
     HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false)
   }
 
+  override def doProduce(ctx: CodegenContext): String = {
+    // Specialize `doProduce` code for full outer join, because full outer 
join needs to
+    // iterate streamed and build side separately.
+    if (joinType != FullOuter) {
+      return super.doProduce(ctx)
+    }
+
+    val HashedRelationInfo(relationTerm, _, _) = prepareRelation(ctx)
+
+    // Inline mutable state since not many join operations in a task
+    val keyIsUnique = ctx.addMutableState("boolean", "keyIsUnique",
+      v => s"$v = $relationTerm.keyIsUnique();", forceInline = true)
+    val streamedInput = ctx.addMutableState("scala.collection.Iterator", 
"streamedInput",
+      v => s"$v = inputs[0];", forceInline = true)
+    val buildInput = ctx.addMutableState("scala.collection.Iterator", 
"buildInput",
+      v => s"$v = $relationTerm.valuesWithKeyIndex();", forceInline = true)
+    val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", 
forceInline = true)
+    val buildRow = ctx.addMutableState("InternalRow", "buildRow", forceInline 
= true)
+
+    // Generate variables and related code from streamed side
+    val streamedVars = genOneSideJoinVars(ctx, streamedRow, streamedPlan, 
setDefaultValue = false)
+    val streamedKeyVariables = evaluateRequiredVariables(streamedOutput, 
streamedVars,
+      AttributeSet.fromAttributeSets(streamedKeys.map(_.references)))
+    ctx.currentVars = streamedVars
+    val streamedKeyExprCode = GenerateUnsafeProjection.createCode(ctx, 
streamedBoundKeys)
+    val streamedKeyEv =
+      s"""
+         |$streamedKeyVariables
+         |${streamedKeyExprCode.code}
+       """.stripMargin
+    val streamedKeyAnyNull = s"${streamedKeyExprCode.value}.anyNull()"
+
+    // Generate code for join condition
+    val (_, conditionCheck, _) =
+      getJoinCondition(ctx, streamedVars, streamedPlan, buildPlan, 
Some(buildRow))
+
+    // Generate code for result output in separate function, as we need to 
output result from
+    // multiple places in join code.
+    val streamedResultVars = genOneSideJoinVars(
+      ctx, streamedRow, streamedPlan, setDefaultValue = true)
+    val buildResultVars = genOneSideJoinVars(
+      ctx, buildRow, buildPlan, setDefaultValue = true)
+    val resultVars = buildSide match {
+      case BuildLeft => buildResultVars ++ streamedResultVars
+      case BuildRight => streamedResultVars ++ buildResultVars
+    }
+    val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
+    ctx.addNewFunction(consumeFullOuterJoinRow,
+      s"""
+         |private void $consumeFullOuterJoinRow() {
+         |  ${metricTerm(ctx, "numOutputRows")}.add(1);
+         |  ${consume(ctx, resultVars)}
+         |}
+       """.stripMargin)
+
+    val joinWithUniqueKey = codegenFullOuterJoinWithUniqueKey(
+      ctx, (streamedRow, buildRow), (streamedInput, buildInput), 
streamedKeyEv, streamedKeyAnyNull,
+      streamedKeyExprCode.value, relationTerm, conditionCheck, 
consumeFullOuterJoinRow)
+    val joinWithNonUniqueKey = codegenFullOuterJoinWithNonUniqueKey(
+      ctx, (streamedRow, buildRow), (streamedInput, buildInput), 
streamedKeyEv, streamedKeyAnyNull,
+      streamedKeyExprCode.value, relationTerm, conditionCheck, 
consumeFullOuterJoinRow)
+
+    s"""
+       |if ($keyIsUnique) {
+       |  $joinWithUniqueKey
+       |} else {
+       |  $joinWithNonUniqueKey
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generates the code for full outer join with unique join keys.
+   * This is code-gen version of `fullOuterJoinWithUniqueKey()`.
+   */
+  private def codegenFullOuterJoinWithUniqueKey(
+      ctx: CodegenContext,
+      rows: (String, String),
+      inputs: (String, String),
+      streamedKeyEv: String,
+      streamedKeyAnyNull: String,
+      streamedKeyValue: ExprValue,
+      relationTerm: String,
+      conditionCheck: String,
+      consumeFullOuterJoinRow: String): String = {
+    // Inline mutable state since not many join operations in a task
+    val matchedKeySetClsName = classOf[BitSet].getName
+    val matchedKeySet = ctx.addMutableState(matchedKeySetClsName, 
"matchedKeySet",
+      v => s"$v = new 
$matchedKeySetClsName($relationTerm.maxNumKeysIndex());", forceInline = true)
+    val rowWithIndexClsName = classOf[ValueRowWithKeyIndex].getName
+    val rowWithIndex = ctx.freshName("rowWithIndex")
+    val foundMatch = ctx.freshName("foundMatch")
+    val (streamedRow, buildRow) = rows
+    val (streamedInput, buildInput) = inputs
+
+    val joinStreamSide =
+      s"""
+         |while ($streamedInput.hasNext()) {
+         |  $streamedRow = (InternalRow) $streamedInput.next();
+         |
+         |  // generate join key for stream side
+         |  $streamedKeyEv
+         |
+         |  // find matches from HashedRelation
+         |  boolean $foundMatch = false;
+         |  $buildRow = null;
+         |  $rowWithIndexClsName $rowWithIndex = $streamedKeyAnyNull ? null:
+         |    $relationTerm.getValueWithKeyIndex($streamedKeyValue);
+         |
+         |  if ($rowWithIndex != null) {
+         |    $buildRow = $rowWithIndex.getValue();
+         |    // check join condition
+         |    $conditionCheck {
+         |      // set key index in matched keys set
+         |      $matchedKeySet.set($rowWithIndex.getKeyIndex());
+         |      $foundMatch = true;
+         |    }
+         |
+         |    if (!$foundMatch) {
+         |      $buildRow = null;
+         |    }
+         |  }
+         |
+         |  $consumeFullOuterJoinRow();
+         |  if (shouldStop()) return;
+         |}
+       """.stripMargin
+
+    val filterBuildSide =
+      s"""
+         |$streamedRow = null;
+         |
+         |// find non-matched rows from HashedRelation
+         |while ($buildInput.hasNext()) {
+         |  $rowWithIndexClsName $rowWithIndex = ($rowWithIndexClsName) 
$buildInput.next();
+         |
+         |  // check if key index is not in matched keys set
+         |  if (!$matchedKeySet.get($rowWithIndex.getKeyIndex())) {
+         |    $buildRow = $rowWithIndex.getValue();
+         |    $consumeFullOuterJoinRow();
+         |  }
+         |
+         |  if (shouldStop()) return;
+         |}
+       """.stripMargin
+
+    s"""
+       |$joinStreamSide
+       |$filterBuildSide
+     """.stripMargin
+  }
+
+  /**
+   * Generates the code for full outer join with non-unique join keys.
+   * This is code-gen version of `fullOuterJoinWithNonUniqueKey()`.
+   */
+  private def codegenFullOuterJoinWithNonUniqueKey(
+      ctx: CodegenContext,
+      rows: (String, String),
+      inputs: (String, String),
+      streamedKeyEv: String,
+      streamedKeyAnyNull: String,
+      streamedKeyValue: ExprValue,
+      relationTerm: String,
+      conditionCheck: String,
+      consumeFullOuterJoinRow: String): String = {
+    // Inline mutable state since not many join operations in a task
+    val matchedRowSetClsName = classOf[OpenHashSet[_]].getName
+    val matchedRowSet = ctx.addMutableState(matchedRowSetClsName, 
"matchedRowSet",
+      v => s"$v = new 
$matchedRowSetClsName(scala.reflect.ClassTag$$.MODULE$$.Long());",
+      forceInline = true)
+    val prevKeyIndex = ctx.addMutableState("int", "prevKeyIndex",
+      v => s"$v = -1;", forceInline = true)
+    val valueIndex = ctx.addMutableState("int", "valueIndex",
+      v => s"$v = -1;", forceInline = true)
+    val rowWithIndexClsName = classOf[ValueRowWithKeyIndex].getName
+    val rowWithIndex = ctx.freshName("rowWithIndex")
+    val buildIterator = ctx.freshName("buildIterator")
+    val foundMatch = ctx.freshName("foundMatch")
+    val keyIndex = ctx.freshName("keyIndex")
+    val (streamedRow, buildRow) = rows
+    val (streamedInput, buildInput) = inputs
+
+    val rowIndex = s"(((long)$keyIndex) << 32) | $valueIndex"
+
+    val joinStreamSide =
+      s"""
+         |while ($streamedInput.hasNext()) {
+         |  $streamedRow = (InternalRow) $streamedInput.next();
+         |
+         |  // generate join key for stream side
+         |  $streamedKeyEv
+         |
+         |  // find matches from HashedRelation
+         |  boolean $foundMatch = false;
+         |  $buildRow = null;
+         |  scala.collection.Iterator $buildIterator = $streamedKeyAnyNull ? 
null:
+         |    $relationTerm.getWithKeyIndex($streamedKeyValue);
+         |
+         |  int $valueIndex = -1;
+         |  while ($buildIterator != null && $buildIterator.hasNext()) {
+         |    $rowWithIndexClsName $rowWithIndex = ($rowWithIndexClsName) 
$buildIterator.next();
+         |    int $keyIndex = $rowWithIndex.getKeyIndex();
+         |    $buildRow = $rowWithIndex.getValue();
+         |    $valueIndex++;
+         |
+         |    // check join condition
+         |    $conditionCheck {
+         |      // set row index in matched row set
+         |      $matchedRowSet.add($rowIndex);
+         |      $foundMatch = true;
+         |      $consumeFullOuterJoinRow();
+         |    }
+         |  }
+         |
+         |  if (!$foundMatch) {
+         |    $buildRow = null;
+         |    $consumeFullOuterJoinRow();
+         |  }
+         |
+         |  if (shouldStop()) return;
+         |}
+       """.stripMargin
+
+    val filterBuildSide =
+      s"""
+         |$streamedRow = null;
+         |
+         |// find non-matched rows from HashedRelation
+         |while ($buildInput.hasNext()) {
+         |  $rowWithIndexClsName $rowWithIndex = ($rowWithIndexClsName) 
$buildInput.next();
+         |  int $keyIndex = $rowWithIndex.getKeyIndex();
+         |  if ($prevKeyIndex == -1 || $keyIndex != $prevKeyIndex) {
+         |    $valueIndex = 0;
+         |    $prevKeyIndex = $keyIndex;
+         |  } else {
+         |    $valueIndex += 1;
+         |  }
+         |
+         |  // check if row index is not in matched row set
+         |  if (!$matchedRowSet.contains($rowIndex)) {
+         |    $buildRow = $rowWithIndex.getValue();
+         |    $consumeFullOuterJoinRow();
+         |  }
+         |
+         |  if (shouldStop()) return;
+         |}
+       """.stripMargin
+
+    s"""
+       |$joinStreamSide
+       |$filterBuildSide
+     """.stripMargin
+  }
+
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan, newRight: SparkPlan): ShuffledHashJoinExec =
     copy(left = newLeft, right = newRight)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 18f584b..66054bf 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -655,7 +655,8 @@ case class SortMergeJoinExec(
     // Create variables for row from both sides.
     val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow)
     val bufferedRow = ctx.freshName("bufferedRow")
-    val bufferedVars = genBuildSideVars(ctx, bufferedRow, bufferedPlan)
+    val setDefaultValue = joinType == LeftOuter || joinType == RightOuter
+    val bufferedVars = genOneSideJoinVars(ctx, bufferedRow, bufferedPlan, 
setDefaultValue)
 
     val iterator = ctx.freshName("iterator")
     val numOutput = metricTerm(ctx, "numOutputRows")
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 6cc6e33..7da813c 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -149,7 +149,7 @@ class WholeStageCodegenSuite extends QueryTest with 
SharedSparkSession
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, 
"2")))
   }
 
-  test("ShuffledHashJoin should be included in WholeStageCodegen") {
+  test("Inner ShuffledHashJoin should be included in WholeStageCodegen") {
     val df1 = spark.range(5).select($"id".as("k1"))
     val df2 = spark.range(15).select($"id".as("k2"))
     val df3 = spark.range(6).select($"id".as("k3"))
@@ -171,6 +171,49 @@ class WholeStageCodegenSuite extends QueryTest with 
SharedSparkSession
       Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 
4)))
   }
 
+  test("Full Outer ShuffledHashJoin should be included in WholeStageCodegen") {
+    val df1 = spark.range(5).select($"id".as("k1"))
+    val df2 = spark.range(10).select($"id".as("k2"))
+    val df3 = spark.range(3).select($"id".as("k3"))
+
+    // test one join with unique key from build side
+    val joinUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", 
"full_outer")
+    assert(joinUniqueDF.queryExecution.executedPlan.collect {
+      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+    }.size === 1)
+    checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), 
Row(4, 4),
+      Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
+
+    // test one join with non-unique key from build side
+    val joinNonUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2" % 
3, "full_outer")
+    assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
+      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+    }.size === 1)
+    checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 
9), Row(1, 1),
+      Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), 
Row(4, null)))
+
+    // test one join with non-equi condition
+    val joinWithNonEquiDF = df1.join(df2.hint("SHUFFLE_HASH"),
+      $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
+    assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
+      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+    }.size === 1)
+    checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 
1),
+      Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 
3), Row(null, 4),
+      Row(null, 5)))
+
+    // test two joins
+    val twoJoinsDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", 
"full_outer")
+      .join(df3.hint("SHUFFLE_HASH"), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, 
"full_outer")
+    assert(twoJoinsDF.queryExecution.executedPlan.collect {
+      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+    }.size === 2)
+    checkAnswer(twoJoinsDF,
+      Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 
4, null),
+        Row(null, 5, null), Row(null, 6, null), Row(null, 7, null), Row(null, 
8, null),
+        Row(null, 9, null), Row(null, null, 1)))
+  }
+
   test("Left/Right Outer SortMergeJoin should be included in 
WholeStageCodegen") {
     val df1 = spark.range(10).select($"id".as("k1"))
     val df2 = spark.range(4).select($"id".as("k2"))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 229d756..4f78833 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -304,4 +304,25 @@ class OuterJoinSuite extends SparkPlanTest with 
SharedSparkSession {
       (null, null, 7, 7.0)
     )
   )
+
+  testOuterJoin(
+    "full outer join with unique keys",
+    uniqueLeft,
+    uniqueRight,
+    FullOuter,
+    uniqueCondition,
+    Seq(
+      (null, null, null, null),
+      (null, null, null, null),
+      (1, 2.0, null, null),
+      (2, 1.0, 2, 3.0),
+      (3, 3.0, null, null),
+      (5, 1.0, 5, 3.0),
+      (6, 6.0, null, null),
+      (null, null, 0, 0.0),
+      (null, null, 3, 2.0),
+      (null, null, 4, 1.0),
+      (null, null, 7, 7.0)
+    )
+  )
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to