This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 78cc91c  [SPARK-32567][SQL] Add code-gen for full outer shuffled hash join
78cc91c is described below

commit 78cc91c962abd48d7ec2e9721d1e1429f802dced
Author: Cheng Su <chen...@fb.com>
AuthorDate: Wed Nov 3 11:18:12 2021 +0800

    [SPARK-32567][SQL] Add code-gen for full outer shuffled hash join

    ### What changes were proposed in this pull request?

    This PR adds code-gen support for FULL OUTER shuffled hash join. The main change is in
    `ShuffledHashJoinExec.scala:doProduce()`, which now generates the FULL OUTER join code path:

    * `ShuffledHashJoinExec.scala:codegenFullOuterJoinWithUniqueKey()` generates the join code
      for a unique join key on the build side.
    * `ShuffledHashJoinExec.scala:codegenFullOuterJoinWithNonUniqueKey()` generates the join code
      for a non-unique join key.

    Example query:

    ```
    val df1 = spark.range(5).select($"id".as("k1"))
    val df2 = spark.range(10).select($"id".as("k2"))
    df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
    ```

    Generated code for the example query: https://gist.github.com/c21/828b782ee81827f4148939cb50314a7b
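[Editor's note] Output like the linked gist can be reproduced with Spark's codegen debugging helper, which prints the compiled Java for each whole-stage subtree. A minimal sketch, assuming a Spark 3.x `spark-shell` session where `spark` is in scope:

```
import org.apache.spark.sql.execution.debug._ // adds debugCodegen() to Dataset
import spark.implicits._

val df1 = spark.range(5).select($"id".as("k1"))
val df2 = spark.range(10).select($"id".as("k2"))
val joined = df1.join(df2.hint("SHUFFLE_HASH"),
  $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")

// Prints each WholeStageCodegen subtree together with the Java source
// it compiles to, similar to the gist above.
joined.debugCodegen()
```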
    ### Why are the changes needed?

    Improve query performance for FULL OUTER shuffled hash join.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    * Added unit test in `WholeStageCodegenSuite`.
    * Existing unit test in `OuterJoinSuite`.

    Closes #34444 from c21/shj-codegen.

    Authored-by: Cheng Su <chen...@fb.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../joins/BroadcastNestedLoopJoinExec.scala        |   2 +-
 .../spark/sql/execution/joins/HashJoin.scala       |   4 +-
 .../sql/execution/joins/JoinCodegenSupport.scala   |  50 ++--
 .../sql/execution/joins/ShuffledHashJoinExec.scala | 260 ++++++++++++++++++++-
 .../sql/execution/joins/SortMergeJoinExec.scala    |   3 +-
 .../sql/execution/WholeStageCodegenSuite.scala     |  45 +++-
 .../spark/sql/execution/joins/OuterJoinSuite.scala |  21 ++
 7 files changed, 349 insertions(+), 36 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 77a30b7..0677211 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -463,7 +463,7 @@ case class BroadcastNestedLoopJoinExec(
   private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx)
     val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast)
-    val buildVars = genBuildSideVars(ctx, buildRow, broadcast)
+    val buildVars = genOneSideJoinVars(ctx, buildRow, broadcast, setDefaultValue = true)
 
     val resultVars = buildSide match {
       case BuildLeft => buildVars ++ input
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index f87acb8..0e8bb84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -444,7 +444,7 @@ trait HashJoin extends JoinCodegenSupport {
     val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
-    val buildVars = genBuildSideVars(ctx, matched, buildPlan)
+    val buildVars = genOneSideJoinVars(ctx, matched, buildPlan, setDefaultValue = true)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     // filter the output via condition
@@ -646,7 +646,7 @@ trait HashJoin extends JoinCodegenSupport {
     val existsVar = ctx.freshName("exists")
 
     val matched = ctx.freshName("matched")
-    val buildVars = genBuildSideVars(ctx, matched, buildPlan)
+    val buildVars = genOneSideJoinVars(ctx, matched, buildPlan, setDefaultValue = false)
     val checkCondition = if (condition.isDefined) {
       val expr = condition.get
       // evaluate the variables from build side that used by condition
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala
index 96aa0be..75f0a35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 
 /**
@@ -30,7 +29,7 @@ trait JoinCodegenSupport extends CodegenSupport with BaseJoinExec {
 
   /**
    * Generate the (non-equi) condition used to filter joined rows.
-   * This is used in Inner, Left Semi and Left Anti joins.
+   * This is used in Inner, Left Semi, Left Anti and Full Outer joins.
    *
    * @return Tuple of variable name for row of build side, generated code for condition,
    *         and generated code for variables of build side.
@@ -39,13 +38,15 @@ trait JoinCodegenSupport extends CodegenSupport with BaseJoinExec {
       ctx: CodegenContext,
       streamVars: Seq[ExprCode],
       streamPlan: SparkPlan,
-      buildPlan: SparkPlan): (String, String, Seq[ExprCode]) = {
-    val buildRow = ctx.freshName("buildRow")
-    val buildVars = genBuildSideVars(ctx, buildRow, buildPlan)
+      buildPlan: SparkPlan,
+      buildRow: Option[String] = None): (String, String, Seq[ExprCode]) = {
+    val buildSideRow = buildRow.getOrElse(ctx.freshName("buildRow"))
+    val buildVars = genOneSideJoinVars(ctx, buildSideRow, buildPlan, setDefaultValue = false)
     val checkCondition = if (condition.isDefined) {
       val expr = condition.get
       // evaluate the variables from build side that used by condition
       val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+
       // filter the output via condition
       ctx.currentVars = streamVars ++ buildVars
       val ev =
@@ -59,41 +60,38 @@ trait JoinCodegenSupport extends CodegenSupport with BaseJoinExec {
     } else {
       ""
     }
-    (buildRow, checkCondition, buildVars)
+    (buildSideRow, checkCondition, buildVars)
   }
 
   /**
-   * Generates the code for variables of build side.
+   * Generates the code for variables of one child side of join.
    */
-  protected def genBuildSideVars(
+  protected def genOneSideJoinVars(
       ctx: CodegenContext,
-      buildRow: String,
-      buildPlan: SparkPlan): Seq[ExprCode] = {
+      row: String,
+      plan: SparkPlan,
+      setDefaultValue: Boolean): Seq[ExprCode] = {
     ctx.currentVars = null
-    ctx.INPUT_ROW = buildRow
-    buildPlan.output.zipWithIndex.map { case (a, i) =>
+    ctx.INPUT_ROW = row
+    plan.output.zipWithIndex.map { case (a, i) =>
       val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
-      joinType match {
-        case _: InnerLike | LeftSemi | LeftAnti | _: ExistenceJoin =>
-          ev
-        case LeftOuter | RightOuter =>
-          // the variables are needed even there is no matched rows
-          val isNull = ctx.freshName("isNull")
-          val value = ctx.freshName("value")
-          val javaType = CodeGenerator.javaType(a.dataType)
-          val code = code"""
+      if (setDefaultValue) {
+        // the variables are needed even there is no matched rows
+        val isNull = ctx.freshName("isNull")
+        val value = ctx.freshName("value")
+        val javaType = CodeGenerator.javaType(a.dataType)
+        val code = code"""
              |boolean $isNull = true;
             |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
-           |if ($buildRow != null) {
+           |if ($row != null) {
            |  ${ev.code}
            |  $isNull = ${ev.isNull};
            |  $value = ${ev.value};
            |}
          """.stripMargin
-          ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType))
-        case _ =>
-          throw new IllegalArgumentException(
-            s"JoinCodegenSupport.genBuildSideVars should not take $joinType as the JoinType")
+        ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType))
+      } else {
+        ev
       }
     }
   }
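[Editor's note] The refactor above folds the per-join-type `match` into an explicit `setDefaultValue` flag: when set, every column variable of that side gets a type-appropriate default and is only populated when a row is present, which is what lets outer joins emit NULLs for the missing side. A hand-written Scala analogue of the per-column Java that `genOneSideJoinVars` emits, for a single nullable bigint column (illustrative only; the real output is generated Java):

```
import org.apache.spark.sql.catalyst.InternalRow

// Mirrors the generated code when setDefaultValue = true:
// `row` may be null when the join found no match for this side.
def readColumn(row: InternalRow): (Boolean, Long) = {
  var isNull = true // default: treat the column as NULL
  var value = 0L    // default value for the column's Java type (here: long)
  if (row != null) {
    isNull = row.isNullAt(0)
    if (!isNull) {
      value = row.getLong(0)
    }
  }
  (isNull, value)
}
```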
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 47b2bd2..7136229 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -311,11 +311,6 @@ case class ShuffledHashJoinExec(
     streamResultIter ++ buildResultIter
   }
 
-  // TODO(SPARK-32567): support full outer shuffled hash join code-gen
-  override def supportCodegen: Boolean = {
-    joinType != FullOuter
-  }
-
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     streamedPlan.execute() :: buildPlan.execute() :: Nil
   }
@@ -332,6 +327,261 @@ case class ShuffledHashJoinExec(
     HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false)
   }
 
+  override def doProduce(ctx: CodegenContext): String = {
+    // Specialize `doProduce` code for full outer join, because full outer join needs to
+    // iterate streamed and build side separately.
+    if (joinType != FullOuter) {
+      return super.doProduce(ctx)
+    }
+
+    val HashedRelationInfo(relationTerm, _, _) = prepareRelation(ctx)
+
+    // Inline mutable state since not many join operations in a task
+    val keyIsUnique = ctx.addMutableState("boolean", "keyIsUnique",
+      v => s"$v = $relationTerm.keyIsUnique();", forceInline = true)
+    val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput",
+      v => s"$v = inputs[0];", forceInline = true)
+    val buildInput = ctx.addMutableState("scala.collection.Iterator", "buildInput",
+      v => s"$v = $relationTerm.valuesWithKeyIndex();", forceInline = true)
+    val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", forceInline = true)
+    val buildRow = ctx.addMutableState("InternalRow", "buildRow", forceInline = true)
+
+    // Generate variables and related code from streamed side
+    val streamedVars = genOneSideJoinVars(ctx, streamedRow, streamedPlan, setDefaultValue = false)
+    val streamedKeyVariables = evaluateRequiredVariables(streamedOutput, streamedVars,
+      AttributeSet.fromAttributeSets(streamedKeys.map(_.references)))
+    ctx.currentVars = streamedVars
+    val streamedKeyExprCode = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+    val streamedKeyEv =
+      s"""
+         |$streamedKeyVariables
+         |${streamedKeyExprCode.code}
+       """.stripMargin
+    val streamedKeyAnyNull = s"${streamedKeyExprCode.value}.anyNull()"
+
+    // Generate code for join condition
+    val (_, conditionCheck, _) =
+      getJoinCondition(ctx, streamedVars, streamedPlan, buildPlan, Some(buildRow))
+
+    // Generate code for result output in separate function, as we need to output result from
+    // multiple places in join code.
+    val streamedResultVars = genOneSideJoinVars(
+      ctx, streamedRow, streamedPlan, setDefaultValue = true)
+    val buildResultVars = genOneSideJoinVars(
+      ctx, buildRow, buildPlan, setDefaultValue = true)
+    val resultVars = buildSide match {
+      case BuildLeft => buildResultVars ++ streamedResultVars
+      case BuildRight => streamedResultVars ++ buildResultVars
+    }
+    val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
+    ctx.addNewFunction(consumeFullOuterJoinRow,
+      s"""
+         |private void $consumeFullOuterJoinRow() {
+         |  ${metricTerm(ctx, "numOutputRows")}.add(1);
+         |  ${consume(ctx, resultVars)}
+         |}
+       """.stripMargin)
+
+    val joinWithUniqueKey = codegenFullOuterJoinWithUniqueKey(
+      ctx, (streamedRow, buildRow), (streamedInput, buildInput), streamedKeyEv, streamedKeyAnyNull,
+      streamedKeyExprCode.value, relationTerm, conditionCheck, consumeFullOuterJoinRow)
+    val joinWithNonUniqueKey = codegenFullOuterJoinWithNonUniqueKey(
+      ctx, (streamedRow, buildRow), (streamedInput, buildInput), streamedKeyEv, streamedKeyAnyNull,
+      streamedKeyExprCode.value, relationTerm, conditionCheck, consumeFullOuterJoinRow)
+
+    s"""
+       |if ($keyIsUnique) {
+       |  $joinWithUniqueKey
+       |} else {
+       |  $joinWithNonUniqueKey
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generates the code for full outer join with unique join keys.
+   * This is code-gen version of `fullOuterJoinWithUniqueKey()`.
+   */
+  private def codegenFullOuterJoinWithUniqueKey(
+      ctx: CodegenContext,
+      rows: (String, String),
+      inputs: (String, String),
+      streamedKeyEv: String,
+      streamedKeyAnyNull: String,
+      streamedKeyValue: ExprValue,
+      relationTerm: String,
+      conditionCheck: String,
+      consumeFullOuterJoinRow: String): String = {
+    // Inline mutable state since not many join operations in a task
+    val matchedKeySetClsName = classOf[BitSet].getName
+    val matchedKeySet = ctx.addMutableState(matchedKeySetClsName, "matchedKeySet",
+      v => s"$v = new $matchedKeySetClsName($relationTerm.maxNumKeysIndex());", forceInline = true)
+    val rowWithIndexClsName = classOf[ValueRowWithKeyIndex].getName
+    val rowWithIndex = ctx.freshName("rowWithIndex")
+    val foundMatch = ctx.freshName("foundMatch")
+    val (streamedRow, buildRow) = rows
+    val (streamedInput, buildInput) = inputs
+
+    val joinStreamSide =
+      s"""
+         |while ($streamedInput.hasNext()) {
+         |  $streamedRow = (InternalRow) $streamedInput.next();
+         |
+         |  // generate join key for stream side
+         |  $streamedKeyEv
+         |
+         |  // find matches from HashedRelation
+         |  boolean $foundMatch = false;
+         |  $buildRow = null;
+         |  $rowWithIndexClsName $rowWithIndex = $streamedKeyAnyNull ? null:
+         |    $relationTerm.getValueWithKeyIndex($streamedKeyValue);
+         |
+         |  if ($rowWithIndex != null) {
+         |    $buildRow = $rowWithIndex.getValue();
+         |    // check join condition
+         |    $conditionCheck {
+         |      // set key index in matched keys set
+         |      $matchedKeySet.set($rowWithIndex.getKeyIndex());
+         |      $foundMatch = true;
+         |    }
+         |
+         |    if (!$foundMatch) {
+         |      $buildRow = null;
+         |    }
+         |  }
+         |
+         |  $consumeFullOuterJoinRow();
+         |  if (shouldStop()) return;
+         |}
+       """.stripMargin
+
+    val filterBuildSide =
+      s"""
+         |$streamedRow = null;
+         |
+         |// find non-matched rows from HashedRelation
+         |while ($buildInput.hasNext()) {
+         |  $rowWithIndexClsName $rowWithIndex = ($rowWithIndexClsName) $buildInput.next();
+         |
+         |  // check if key index is not in matched keys set
+         |  if (!$matchedKeySet.get($rowWithIndex.getKeyIndex())) {
+         |    $buildRow = $rowWithIndex.getValue();
+         |    $consumeFullOuterJoinRow();
+         |  }
+         |
+         |  if (shouldStop()) return;
+         |}
+       """.stripMargin
+
+    s"""
+       |$joinStreamSide
+       |$filterBuildSide
+     """.stripMargin
+  }
+
+  /**
+   * Generates the code for full outer join with non-unique join keys.
+   * This is code-gen version of `fullOuterJoinWithNonUniqueKey()`.
+   */
+  private def codegenFullOuterJoinWithNonUniqueKey(
+      ctx: CodegenContext,
+      rows: (String, String),
+      inputs: (String, String),
+      streamedKeyEv: String,
+      streamedKeyAnyNull: String,
+      streamedKeyValue: ExprValue,
+      relationTerm: String,
+      conditionCheck: String,
+      consumeFullOuterJoinRow: String): String = {
+    // Inline mutable state since not many join operations in a task
+    val matchedRowSetClsName = classOf[OpenHashSet[_]].getName
+    val matchedRowSet = ctx.addMutableState(matchedRowSetClsName, "matchedRowSet",
+      v => s"$v = new $matchedRowSetClsName(scala.reflect.ClassTag$$.MODULE$$.Long());",
+      forceInline = true)
+    val prevKeyIndex = ctx.addMutableState("int", "prevKeyIndex",
+      v => s"$v = -1;", forceInline = true)
+    val valueIndex = ctx.addMutableState("int", "valueIndex",
+      v => s"$v = -1;", forceInline = true)
+    val rowWithIndexClsName = classOf[ValueRowWithKeyIndex].getName
+    val rowWithIndex = ctx.freshName("rowWithIndex")
+    val buildIterator = ctx.freshName("buildIterator")
+    val foundMatch = ctx.freshName("foundMatch")
+    val keyIndex = ctx.freshName("keyIndex")
+    val (streamedRow, buildRow) = rows
+    val (streamedInput, buildInput) = inputs
+
+    val rowIndex = s"(((long)$keyIndex) << 32) | $valueIndex"
+
+    val joinStreamSide =
+      s"""
+         |while ($streamedInput.hasNext()) {
+         |  $streamedRow = (InternalRow) $streamedInput.next();
+         |
+         |  // generate join key for stream side
+         |  $streamedKeyEv
+         |
+         |  // find matches from HashedRelation
+         |  boolean $foundMatch = false;
+         |  $buildRow = null;
+         |  scala.collection.Iterator $buildIterator = $streamedKeyAnyNull ? null:
+         |    $relationTerm.getWithKeyIndex($streamedKeyValue);
+         |
+         |  int $valueIndex = -1;
+         |  while ($buildIterator != null && $buildIterator.hasNext()) {
+         |    $rowWithIndexClsName $rowWithIndex = ($rowWithIndexClsName) $buildIterator.next();
+         |    int $keyIndex = $rowWithIndex.getKeyIndex();
+         |    $buildRow = $rowWithIndex.getValue();
+         |    $valueIndex++;
+         |
+         |    // check join condition
+         |    $conditionCheck {
+         |      // set row index in matched row set
+         |      $matchedRowSet.add($rowIndex);
+         |      $foundMatch = true;
+         |      $consumeFullOuterJoinRow();
+         |    }
+         |  }
+         |
+         |  if (!$foundMatch) {
+         |    $buildRow = null;
+         |    $consumeFullOuterJoinRow();
+         |  }
+         |
+         |  if (shouldStop()) return;
+         |}
       """.stripMargin
+
+    val filterBuildSide =
+      s"""
+         |$streamedRow = null;
+         |
+         |// find non-matched rows from HashedRelation
+         |while ($buildInput.hasNext()) {
+         |  $rowWithIndexClsName $rowWithIndex = ($rowWithIndexClsName) $buildInput.next();
+         |  int $keyIndex = $rowWithIndex.getKeyIndex();
+         |  if ($prevKeyIndex == -1 || $keyIndex != $prevKeyIndex) {
+         |    $valueIndex = 0;
+         |    $prevKeyIndex = $keyIndex;
+         |  } else {
+         |    $valueIndex += 1;
+         |  }
+         |
+         |  // check if row index is not in matched row set
+         |  if (!$matchedRowSet.contains($rowIndex)) {
+         |    $buildRow = $rowWithIndex.getValue();
+         |    $consumeFullOuterJoinRow();
+         |  }
+         |
+         |  if (shouldStop()) return;
+         |}
+       """.stripMargin
+
+    s"""
+       |$joinStreamSide
+       |$filterBuildSide
+     """.stripMargin
+  }
+
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan, newRight: SparkPlan): ShuffledHashJoinExec =
     copy(left = newLeft, right = newRight)
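[Editor's note] Both generated variants follow the same two-phase shape: probe the `HashedRelation` for every streamed row while recording which build rows matched, then scan the build side and emit the rows that were never matched. With unique keys a `BitSet` over key indexes is enough; with non-unique keys each build row must be addressed by `(keyIndex, valueIndex)`, packed into one long exactly as the generated `(((long) keyIndex) << 32) | valueIndex`. A self-contained Scala model of the non-unique variant (the non-equi join condition is omitted for brevity; the names and plain collections are illustrative, not Spark's `HashedRelation` API):

```
import scala.collection.mutable

// Packs a build row's address the same way the generated code does.
def packRowIndex(keyIndex: Int, valueIndex: Int): Long =
  (keyIndex.toLong << 32) | valueIndex

def fullOuterHashJoin[K, L, R](
    streamed: Seq[(K, L)],
    build: Seq[(K, R)]): Seq[(Option[L], Option[R])] = {
  // key -> key index, and key -> rows under that key, mimicking valuesWithKeyIndex()
  val keyIndexOf = build.map(_._1).distinct.zipWithIndex.toMap
  val rowsOf = build.groupBy(_._1).map { case (k, krs) => k -> krs.map(_._2) }
  val matchedRows = mutable.Set.empty[Long]

  // Phase 1: probe every streamed row, remembering exactly which build rows matched.
  val streamSide = streamed.flatMap { case (k, l) =>
    rowsOf.get(k) match {
      case Some(rs) =>
        rs.zipWithIndex.map { case (r, valueIndex) =>
          matchedRows += packRowIndex(keyIndexOf(k), valueIndex)
          (Some(l), Some(r))
        }
      case None =>
        Seq((Some(l), None)) // streamed row with no build-side match
    }
  }
  // Phase 2: emit every build row whose packed index was never marked as matched.
  val buildSide = for {
    (k, rs) <- rowsOf.toSeq
    (r, valueIndex) <- rs.zipWithIndex
    if !matchedRows(packRowIndex(keyIndexOf(k), valueIndex))
  } yield (None, Some(r))

  streamSide ++ buildSide
}
```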
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 18f584b..66054bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -655,7 +655,8 @@ case class SortMergeJoinExec(
     // Create variables for row from both sides.
     val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow)
     val bufferedRow = ctx.freshName("bufferedRow")
-    val bufferedVars = genBuildSideVars(ctx, bufferedRow, bufferedPlan)
+    val setDefaultValue = joinType == LeftOuter || joinType == RightOuter
+    val bufferedVars = genOneSideJoinVars(ctx, bufferedRow, bufferedPlan, setDefaultValue)
     val iterator = ctx.freshName("iterator")
     val numOutput = metricTerm(ctx, "numOutputRows")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 6cc6e33..7da813c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -149,7 +149,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
   }
 
-  test("ShuffledHashJoin should be included in WholeStageCodegen") {
+  test("Inner ShuffledHashJoin should be included in WholeStageCodegen") {
     val df1 = spark.range(5).select($"id".as("k1"))
     val df2 = spark.range(15).select($"id".as("k2"))
     val df3 = spark.range(6).select($"id".as("k3"))
@@ -171,6 +171,49 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
       Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
   }
 
+  test("Full Outer ShuffledHashJoin should be included in WholeStageCodegen") {
+    val df1 = spark.range(5).select($"id".as("k1"))
+    val df2 = spark.range(10).select($"id".as("k2"))
+    val df3 = spark.range(3).select($"id".as("k3"))
+
+    // test one join with unique key from build side
+    val joinUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", "full_outer")
+    assert(joinUniqueDF.queryExecution.executedPlan.collect {
+      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+    }.size === 1)
+    checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
+      Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
+
+    // test one join with non-unique key from build side
+    val joinNonUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2" % 3, "full_outer")
+    assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
+      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+    }.size === 1)
+    checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
+      Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
+
+    // test one join with non-equi condition
+    val joinWithNonEquiDF = df1.join(df2.hint("SHUFFLE_HASH"),
+      $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
+    assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
+      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+    }.size === 1)
+    checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1),
+      Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4),
+      Row(null, 5)))
+
+    // test two joins
+    val twoJoinsDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", "full_outer")
+      .join(df3.hint("SHUFFLE_HASH"), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "full_outer")
+    assert(twoJoinsDF.queryExecution.executedPlan.collect {
+      case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+    }.size === 2)
+    checkAnswer(twoJoinsDF,
+      Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null),
+        Row(null, 5, null), Row(null, 6, null), Row(null, 7, null), Row(null, 8, null),
+        Row(null, 9, null), Row(null, null, 1)))
+  }
+
   test("Left/Right Outer SortMergeJoin should be included in WholeStageCodegen") {
     val df1 = spark.range(10).select($"id".as("k1"))
     val df2 = spark.range(4).select($"id".as("k2"))
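[Editor's note] The plan-collect pattern in the new test is also a quick way to confirm, outside the suite, that a query actually takes the code-generated path. A minimal sketch, reusing `df1` and `df2` from the commit-message example:

```
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec

val fullOuterDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", "full_outer")
// After this change, the join appears wrapped in WholeStageCodegenExec
// instead of running through the interpreted fullOuterJoin() path.
val codegenJoins = fullOuterDF.queryExecution.executedPlan.collect {
  case WholeStageCodegenExec(_: ShuffledHashJoinExec) => true
}
assert(codegenJoins.size == 1)
```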
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 229d756..4f78833 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -304,4 +304,25 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
       (null, null, 7, 7.0)
     )
   )
+
+  testOuterJoin(
+    "full outer join with unique keys",
+    uniqueLeft,
+    uniqueRight,
+    FullOuter,
+    uniqueCondition,
+    Seq(
+      (null, null, null, null),
+      (null, null, null, null),
+      (1, 2.0, null, null),
+      (2, 1.0, 2, 3.0),
+      (3, 3.0, null, null),
+      (5, 1.0, 5, 3.0),
+      (6, 6.0, null, null),
+      (null, null, 0, 0.0),
+      (null, null, 3, 2.0),
+      (null, null, 4, 1.0),
+      (null, null, 7, 7.0)
+    )
+  )
 }