This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2db8cfb3bd9 [SPARK-44060][SQL] Code-gen for build side outer shuffled hash join
2db8cfb3bd9 is described below

commit 2db8cfb3bd9bf5e85379c6d5ca414d36cfd9292d
Author: Szehon Ho <szehon.apa...@gmail.com>
AuthorDate: Fri Jun 30 22:04:22 2023 -0700

    [SPARK-44060][SQL] Code-gen for build side outer shuffled hash join
    
    ### What changes were proposed in this pull request?
    Add codegen for shuffled hash join when the outer side is the build side (i.e., left outer join with build left, or right outer join with build right).
    
     ### Why are the changes needed?
    The implementation in https://github.com/apache/spark/pull/41398 only covered the non-codegen path, so codegen was disabled in this scenario.
    
     ### Does this PR introduce _any_ user-facing change?
    No
    
     ### How was this patch tested?
    New unit test in WholeStageCodegenSuite
    
    Closes #41614 from szehon-ho/same_side_outer_join_codegen_master.
    
    Authored-by: Szehon Ho <szehon.apa...@gmail.com>
    Signed-off-by: huaxingao <huaxin_...@apple.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |   9 ++
 .../sql/execution/joins/ShuffledHashJoinExec.scala |  68 ++++++----
 .../scala/org/apache/spark/sql/JoinSuite.scala     | 146 +++++++++++----------
 .../sql/execution/WholeStageCodegenSuite.scala     |  89 +++++++++++++
 4 files changed, 217 insertions(+), 95 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d60f5d170e7..270508139e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2182,6 +2182,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN =
+    buildConf("spark.sql.codegen.join.buildSideOuterShuffledHashJoin.enabled")
+      .internal()
+      .doc("When true, enable code-gen for an OUTER shuffled hash join where 
outer side" +
+        " is the build side.")
+      .version("3.5.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val ENABLE_FULL_OUTER_SORT_MERGE_JOIN_CODEGEN =
     buildConf("spark.sql.codegen.join.fullOuterSortMergeJoin.enabled")
       .internal()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 8953bf19f35..974f6f9e50c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -340,8 +340,10 @@ case class ShuffledHashJoinExec(
 
   override def supportCodegen: Boolean = joinType match {
     case FullOuter => 
conf.getConf(SQLConf.ENABLE_FULL_OUTER_SHUFFLED_HASH_JOIN_CODEGEN)
-    case LeftOuter if buildSide == BuildLeft => false
-    case RightOuter if buildSide == BuildRight => false
+    case LeftOuter if buildSide == BuildLeft =>
+      conf.getConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN)
+    case RightOuter if buildSide == BuildRight =>
+      conf.getConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN)
     case _ => true
   }
 
@@ -362,9 +364,15 @@ case class ShuffledHashJoinExec(
   }
 
   override def doProduce(ctx: CodegenContext): String = {
-    // Specialize `doProduce` code for full outer join, because full outer 
join needs to
-    // iterate streamed and build side separately.
-    if (joinType != FullOuter) {
+    // Specialize `doProduce` code for full outer join and build-side outer 
join,
+    // because we need to iterate streamed and build side separately.
+    val specializedProduce = joinType match {
+      case FullOuter => true
+      case LeftOuter if buildSide == BuildLeft => true
+      case RightOuter if buildSide == BuildRight => true
+      case _ => false
+    }
+    if (!specializedProduce) {
       return super.doProduce(ctx)
     }
 
@@ -407,21 +415,24 @@ case class ShuffledHashJoinExec(
       case BuildLeft => buildResultVars ++ streamedResultVars
       case BuildRight => streamedResultVars ++ buildResultVars
     }
-    val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
-    ctx.addNewFunction(consumeFullOuterJoinRow,
+    val consumeOuterJoinRow = ctx.freshName("consumeOuterJoinRow")
+    ctx.addNewFunction(consumeOuterJoinRow,
       s"""
-         |private void $consumeFullOuterJoinRow() throws java.io.IOException {
+         |private void $consumeOuterJoinRow() throws java.io.IOException {
          |  ${metricTerm(ctx, "numOutputRows")}.add(1);
          |  ${consume(ctx, resultVars)}
          |}
        """.stripMargin)
 
-    val joinWithUniqueKey = codegenFullOuterJoinWithUniqueKey(
+    val isFullOuterJoin = joinType == FullOuter
+    val joinWithUniqueKey = codegenBuildSideOrFullOuterJoinWithUniqueKey(
       ctx, (streamedRow, buildRow), (streamedInput, buildInput), 
streamedKeyEv, streamedKeyAnyNull,
-      streamedKeyExprCode.value, relationTerm, conditionCheck, 
consumeFullOuterJoinRow)
-    val joinWithNonUniqueKey = codegenFullOuterJoinWithNonUniqueKey(
+      streamedKeyExprCode.value, relationTerm, conditionCheck, 
consumeOuterJoinRow,
+      isFullOuterJoin)
+    val joinWithNonUniqueKey = codegenBuildSideOrFullOuterJoinNonUniqueKey(
       ctx, (streamedRow, buildRow), (streamedInput, buildInput), 
streamedKeyEv, streamedKeyAnyNull,
-      streamedKeyExprCode.value, relationTerm, conditionCheck, 
consumeFullOuterJoinRow)
+      streamedKeyExprCode.value, relationTerm, conditionCheck, 
consumeOuterJoinRow,
+      isFullOuterJoin)
 
     s"""
        |if ($keyIsUnique) {
@@ -433,10 +444,10 @@ case class ShuffledHashJoinExec(
   }
 
   /**
-   * Generates the code for full outer join with unique join keys.
-   * This is code-gen version of `fullOuterJoinWithUniqueKey()`.
+   * Generates the code for build-side or full outer join with unique join 
keys.
+   * This is code-gen version of `buildSideOrFullOuterJoinUniqueKey()`.
    */
-  private def codegenFullOuterJoinWithUniqueKey(
+  private def codegenBuildSideOrFullOuterJoinWithUniqueKey(
       ctx: CodegenContext,
       rows: (String, String),
       inputs: (String, String),
@@ -445,7 +456,8 @@ case class ShuffledHashJoinExec(
       streamedKeyValue: ExprValue,
       relationTerm: String,
       conditionCheck: String,
-      consumeFullOuterJoinRow: String): String = {
+      consumeOuterJoinRow: String,
+      isFullOuterJoin: Boolean): String = {
     // Inline mutable state since not many join operations in a task
     val matchedKeySetClsName = classOf[BitSet].getName
     val matchedKeySet = ctx.addMutableState(matchedKeySetClsName, 
"matchedKeySet",
@@ -484,7 +496,10 @@ case class ShuffledHashJoinExec(
          |    }
          |  }
          |
-         |  $consumeFullOuterJoinRow();
+         |  if ($foundMatch || $isFullOuterJoin) {
+         |    $consumeOuterJoinRow();
+         |  }
+         |
          |  if (shouldStop()) return;
          |}
        """.stripMargin
@@ -500,7 +515,7 @@ case class ShuffledHashJoinExec(
          |  // check if key index is not in matched keys set
          |  if (!$matchedKeySet.get($rowWithIndex.getKeyIndex())) {
          |    $buildRow = $rowWithIndex.getValue();
-         |    $consumeFullOuterJoinRow();
+         |    $consumeOuterJoinRow();
          |  }
          |
          |  if (shouldStop()) return;
@@ -514,10 +529,10 @@ case class ShuffledHashJoinExec(
   }
 
   /**
-   * Generates the code for full outer join with non-unique join keys.
-   * This is code-gen version of `fullOuterJoinWithNonUniqueKey()`.
+   * Generates the code for build-side or full outer join with non-unique join 
keys.
+   * This is code-gen version of `buildSideOrFullOuterJoinNonUniqueKey()`.
    */
-  private def codegenFullOuterJoinWithNonUniqueKey(
+  private def codegenBuildSideOrFullOuterJoinNonUniqueKey(
       ctx: CodegenContext,
       rows: (String, String),
       inputs: (String, String),
@@ -526,7 +541,8 @@ case class ShuffledHashJoinExec(
       streamedKeyValue: ExprValue,
       relationTerm: String,
       conditionCheck: String,
-      consumeFullOuterJoinRow: String): String = {
+      consumeOuterJoinRow: String,
+      isFullOuterJoin: Boolean): String = {
     // Inline mutable state since not many join operations in a task
     val matchedRowSetClsName = classOf[OpenHashSet[_]].getName
     val matchedRowSet = ctx.addMutableState(matchedRowSetClsName, 
"matchedRowSet",
@@ -572,13 +588,15 @@ case class ShuffledHashJoinExec(
          |      // set row index in matched row set
          |      $matchedRowSet.add($rowIndex);
          |      $foundMatch = true;
-         |      $consumeFullOuterJoinRow();
+         |      $consumeOuterJoinRow();
          |    }
          |  }
          |
          |  if (!$foundMatch) {
          |    $buildRow = null;
-         |    $consumeFullOuterJoinRow();
+         |    if ($isFullOuterJoin) {
+         |      $consumeOuterJoinRow();
+         |    }
          |  }
          |
          |  if (shouldStop()) return;
@@ -603,7 +621,7 @@ case class ShuffledHashJoinExec(
          |  // check if row index is not in matched row set
          |  if (!$matchedRowSet.contains($rowIndex)) {
          |    $buildRow = $rowWithIndex.getValue();
-         |    $consumeFullOuterJoinRow();
+         |    $consumeOuterJoinRow();
          |  }
          |
          |  if (shouldStop()) return;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 4d0fd2e6513..eb58a77704e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1315,78 +1315,84 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
 
   test("SPARK-36612: Support left outer join build left or right outer join 
build right in " +
     "shuffled hash join") {
-    val inputDFs = Seq(
-      // Test unique join key
-      (spark.range(10).selectExpr("id as k1"),
-        spark.range(30).selectExpr("id as k2"),
-        $"k1" === $"k2"),
-      // Test non-unique join key
-      (spark.range(10).selectExpr("id % 5 as k1"),
-        spark.range(30).selectExpr("id % 5 as k2"),
-        $"k1" === $"k2"),
-      // Test empty build side
-      (spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
-        spark.range(30).selectExpr("id as k2"),
-        $"k1" === $"k2"),
-      // Test empty stream side
-      (spark.range(10).selectExpr("id as k1"),
-        spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
-        $"k1" === $"k2"),
-      // Test empty build and stream side
-      (spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
-        spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
-        $"k1" === $"k2"),
-      // Test string join key
-      (spark.range(10).selectExpr("cast(id * 3 as string) as k1"),
-        spark.range(30).selectExpr("cast(id as string) as k2"),
-        $"k1" === $"k2"),
-      // Test build side at right
-      (spark.range(30).selectExpr("cast(id / 3 as string) as k1"),
-        spark.range(10).selectExpr("cast(id as string) as k2"),
-        $"k1" === $"k2"),
-      // Test NULL join key
-      (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value 
as k1"),
-        spark.range(30).map(i => if (i % 4 == 0) i else 
null).selectExpr("value as k2"),
-        $"k1" === $"k2"),
-      (spark.range(10).map(i => if (i % 3 == 0) i else null).selectExpr("value 
as k1"),
-        spark.range(30).map(i => if (i % 5 == 0) i else 
null).selectExpr("value as k2"),
-        $"k1" === $"k2"),
-      // Test multiple join keys
-      (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr(
-        "value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as 
long) as k3"),
-        spark.range(30).map(i => if (i % 3 == 0) i else null).selectExpr(
-          "value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 as 
long) as k6"),
-        $"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6")
-    )
-
-    // test left outer with left side build
-    inputDFs.foreach { case (df1, df2, joinExprs) =>
-      val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, "leftouter")
-      assert(collect(smjDF.queryExecution.executedPlan) {
-        case _: SortMergeJoinExec => true }.size === 1)
-      val smjResult = smjDF.collect()
-
-      val shjDF = df1.hint("SHUFFLE_HASH").join(df2, joinExprs, "leftouter")
-      assert(collect(shjDF.queryExecution.executedPlan) {
-        case _: ShuffledHashJoinExec => true
-      }.size === 1)
-      // Same result between shuffled hash join and sort merge join
-      checkAnswer(shjDF, smjResult)
-    }
+    Seq("true", "false").foreach{ codegen =>
+      
withSQLConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN.key -> 
codegen) {
+        val inputDFs = Seq(
+          // Test unique join key
+          (spark.range(10).selectExpr("id as k1"),
+            spark.range(30).selectExpr("id as k2"),
+            $"k1" === $"k2"),
+          // Test non-unique join key
+          (spark.range(10).selectExpr("id % 5 as k1"),
+            spark.range(30).selectExpr("id % 5 as k2"),
+            $"k1" === $"k2"),
+          // Test empty build side
+          (spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
+            spark.range(30).selectExpr("id as k2"),
+            $"k1" === $"k2"),
+          // Test empty stream side
+          (spark.range(10).selectExpr("id as k1"),
+            spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
+            $"k1" === $"k2"),
+          // Test empty build and stream side
+          (spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
+            spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
+            $"k1" === $"k2"),
+          // Test string join key
+          (spark.range(10).selectExpr("cast(id * 3 as string) as k1"),
+            spark.range(30).selectExpr("cast(id as string) as k2"),
+            $"k1" === $"k2"),
+          // Test build side at right
+          (spark.range(30).selectExpr("cast(id / 3 as string) as k1"),
+            spark.range(10).selectExpr("cast(id as string) as k2"),
+            $"k1" === $"k2"),
+          // Test NULL join key
+          (spark.range(10).map(i => if (i % 2 == 0) i else 
null).selectExpr("value as k1"),
+            spark.range(30).map(i => if (i % 4 == 0) i else 
null).selectExpr("value as k2"),
+            $"k1" === $"k2"),
+          (spark.range(10).map(i => if (i % 3 == 0) i else 
null).selectExpr("value as k1"),
+            spark.range(30).map(i => if (i % 5 == 0) i else 
null).selectExpr("value as k2"),
+            $"k1" === $"k2"),
+          // Test multiple join keys
+          (spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr(
+            "value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 
as long) as k3"),
+            spark.range(30).map(i => if (i % 3 == 0) i else null).selectExpr(
+              "value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 
as long) as k6"),
+            $"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6")
+        )
 
-    // test right outer with right side build
-    inputDFs.foreach { case (df2, df1, joinExprs) =>
-      val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, "rightouter")
-      assert(collect(smjDF.queryExecution.executedPlan) {
-        case _: SortMergeJoinExec => true }.size === 1)
-      val smjResult = smjDF.collect()
+        // test left outer with left side build
+        inputDFs.foreach { case (df1, df2, joinExprs) =>
+          val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, 
"leftouter")
+          assert(collect(smjDF.queryExecution.executedPlan) {
+            case _: SortMergeJoinExec => true
+          }.size === 1)
+          val smjResult = smjDF.collect()
+
+          val shjDF = df1.hint("SHUFFLE_HASH").join(df2, joinExprs, 
"leftouter")
+          assert(collect(shjDF.queryExecution.executedPlan) {
+            case _: ShuffledHashJoinExec => true
+          }.size === 1)
+          // Same result between shuffled hash join and sort merge join
+          checkAnswer(shjDF, smjResult)
+        }
 
-      val shjDF = df2.join(df1.hint("SHUFFLE_HASH"), joinExprs, "rightouter")
-      assert(collect(shjDF.queryExecution.executedPlan) {
-        case _: ShuffledHashJoinExec => true
-      }.size === 1)
-      // Same result between shuffled hash join and sort merge join
-      checkAnswer(shjDF, smjResult)
+        // test right outer with right side build
+        inputDFs.foreach { case (df2, df1, joinExprs) =>
+          val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, 
"rightouter")
+          assert(collect(smjDF.queryExecution.executedPlan) {
+            case _: SortMergeJoinExec => true
+          }.size === 1)
+          val smjResult = smjDF.collect()
+
+          val shjDF = df2.join(df1.hint("SHUFFLE_HASH"), joinExprs, 
"rightouter")
+          assert(collect(shjDF.queryExecution.executedPlan) {
+            case _: ShuffledHashJoinExec => true
+          }.size === 1)
+          // Same result between shuffled hash join and sort merge join
+          checkAnswer(shjDF, smjResult)
+        }
+      }
     }
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index ac710c32296..0aaeedd5f06 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -232,6 +232,95 @@ class WholeStageCodegenSuite extends QueryTest with 
SharedSparkSession
     }
   }
 
+
+  test("SPARK-44060 Code-gen for build side outer shuffled hash join") {
+    val df1 = spark.range(0, 5).select($"id".as("k1"))
+    val df2 = spark.range(1, 11).select($"id".as("k2"))
+    val df3 = spark.range(2, 5).select($"id".as("k3"))
+
+    withSQLConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN.key 
-> "true") {
+      Seq("SHUFFLE_HASH", "SHUFFLE_MERGE").foreach { hint =>
+        // test right join with unique key from build side
+        val rightJoinUniqueDf = df1.join(df2.hint(hint), $"k1" === $"k2", 
"right_outer")
+        assert(rightJoinUniqueDf.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == 
"SHUFFLE_HASH" => true
+          case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == 
"SHUFFLE_MERGE" => true
+        }.size === 1)
+        checkAnswer(rightJoinUniqueDf, Seq(Row(1, 1), Row(2, 2), Row(3, 3), 
Row(4, 4),
+          Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9),
+          Row(null, 10)))
+        assert(rightJoinUniqueDf.count() === 10)
+
+        // test left join with unique key from build side
+        val leftJoinUniqueDf = df1.hint(hint).join(df2, $"k1" === $"k2", 
"left_outer")
+        assert(leftJoinUniqueDf.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == 
"SHUFFLE_HASH" => true
+          case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == 
"SHUFFLE_MERGE" => true
+        }.size === 1)
+        checkAnswer(leftJoinUniqueDf, Seq(Row(0, null), Row(1, 1), Row(2, 2), 
Row(3, 3), Row(4, 4)))
+        assert(leftJoinUniqueDf.count() === 5)
+
+        // test right join with non-unique key from build side
+        val rightJoinNonUniqueDf = df1.join(df2.hint(hint), $"k1" === $"k2" % 
3, "right_outer")
+        assert(rightJoinNonUniqueDf.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == 
"SHUFFLE_HASH" => true
+          case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == 
"SHUFFLE_MERGE" => true
+        }.size === 1)
+        checkAnswer(rightJoinNonUniqueDf, Seq(Row(0, 3), Row(0, 6), Row(0, 9), 
Row(1, 1),
+          Row(1, 4), Row(1, 7), Row(1, 10), Row(2, 2), Row(2, 5), Row(2, 8)))
+
+        // test left join with non-unique key from build side
+        val leftJoinNonUniqueDf = df1.hint(hint).join(df2, $"k1" === $"k2" % 
3, "left_outer")
+        assert(leftJoinNonUniqueDf.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == 
"SHUFFLE_HASH" => true
+          case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == 
"SHUFFLE_MERGE" => true
+        }.size === 1)
+        checkAnswer(leftJoinNonUniqueDf, Seq(Row(0, 3), Row(0, 6), Row(0, 9), 
Row(1, 1),
+          Row(1, 4), Row(1, 7), Row(1, 10), Row(2, 2), Row(2, 5), Row(2, 8), 
Row(3, null),
+          Row(4, null)))
+
+        // test right join with non-equi condition
+        val rightJoinWithNonEquiDf = df1.join(df2.hint(hint),
+          $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "right_outer")
+        assert(rightJoinWithNonEquiDf.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == 
"SHUFFLE_HASH" => true
+          case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == 
"SHUFFLE_MERGE" => true
+        }.size === 1)
+        checkAnswer(rightJoinWithNonEquiDf, Seq(Row(0, 6), Row(0, 9), Row(1, 
1), Row(1, 7),
+          Row(1, 10), Row(2, 2), Row(2, 8), Row(null, 3), Row(null, 4), 
Row(null, 5)))
+
+        // test left join with non-equi condition
+        val leftJoinWithNonEquiDf = df1.hint(hint).join(df2,
+          $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "left_outer")
+        assert(leftJoinWithNonEquiDf.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == 
"SHUFFLE_HASH" => true
+          case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == 
"SHUFFLE_MERGE" => true
+        }.size === 1)
+        checkAnswer(leftJoinWithNonEquiDf, Seq(Row(0, 6), Row(0, 9), Row(1, 
1), Row(1, 7),
+          Row(1, 10), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null)))
+
+        // test two right joins
+        val twoRightJoinsDf = df1.join(df2.hint(hint), $"k1" === $"k2", 
"right_outer")
+          .join(df3.hint(hint), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, 
"right_outer")
+        assert(twoRightJoinsDf.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == 
"SHUFFLE_HASH" => true
+          case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == 
"SHUFFLE_MERGE" => true
+        }.size === 2)
+        checkAnswer(twoRightJoinsDf, Seq(Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 
4)))
+
+        // test two left joins
+        val twoLeftJoinsDf = df1.hint(hint).join(df2, $"k1" === $"k2", 
"left_outer").hint(hint)
+          .join(df3, $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "left_outer")
+        assert(twoLeftJoinsDf.queryExecution.executedPlan.collect {
+          case WholeStageCodegenExec(_: ShuffledHashJoinExec) if hint == 
"SHUFFLE_HASH" => true
+          case WholeStageCodegenExec(_: SortMergeJoinExec) if hint == 
"SHUFFLE_MERGE" => true
+        }.size === 2)
+        checkAnswer(twoLeftJoinsDf,
+          Seq(Row(0, null, null), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, 3), 
Row(4, 4, 4)))
+      }
+    }
+  }
+
   test("Left/Right Outer SortMergeJoin should be included in 
WholeStageCodegen") {
     val df1 = spark.range(10).select($"id".as("k1"))
     val df2 = spark.range(4).select($"id".as("k2"))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to