This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2ef60f72 [SPARK-35352][SQL] Add code-gen for full outer sort merge join 2ef60f72 is described below commit 2ef60f726c79349cbcda6f34f3e99b32951388bf Author: Cheng Su <chen...@fb.com> AuthorDate: Mon Nov 15 19:34:52 2021 +0800 [SPARK-35352][SQL] Add code-gen for full outer sort merge join ### What changes were proposed in this pull request? This PR adds code-gen for FULL OUTER sort merge join. The change is in `SortMergeJoinExec.scala:codegenFullOuter()`. It follows the same algorithm as the iterator mode - `SortMergeFullOuterJoinScanner`: maintain a buffer for each of the left and right join sides, and iterate over the matched rows in the buffers. Example query: ``` val df1 = spark.range(5).select($"id".as("k1")) val df2 = spark.range(10).select($"id".as("k2")) df1.join(df2.hint(hint), $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer") ``` Example generated code: https://gist.github.com/c21/5cab9751f24ae448d77a259d28cb77d7 In addition, to help review, since this PR triggers changes to several TPCDS plan files, the files below contain the real code changes: * `SortMergeJoinExec.scala` * `WholeStageCodegenSuite.scala` All other files are auto-generated golden file plan changes for TPCDS queries. ### Why are the changes needed? Improve the run-time/CPU performance of FULL OUTER sort merge join. 
Micro benchmark (same query in `JoinBenchmark.scala`): ``` def sortMergeJoin(): Unit = { val N = 2 << 20 codegenBenchmark("sort merge join", N) { val df1 = spark.range(N).selectExpr(s"id * 2 as k1") val df2 = spark.range(N).selectExpr(s"id * 3 as k2") val df = df1.join(df2, col("k1") === col("k2"), "full_outer") assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) df.noop() } } def sortMergeJoinWithDuplicates(): Unit = { val N = 2 << 20 codegenBenchmark("sort merge join with duplicates", N) { val df1 = spark.range(N) .selectExpr(s"(id * 15485863) % ${N*10} as k1") val df2 = spark.range(N) .selectExpr(s"(id * 15485867) % ${N*10} as k2") val df = df1.join(df2, col("k1") === col("k2"), "full_outer") assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) df.noop() } } ``` Seeing 20-30% of run-time improvement: ``` Running benchmark: sort merge join Running case: sort merge join wholestage off Stopped after 2 iterations, 2979 ms Running case: sort merge join wholestage on Stopped after 5 iterations, 5849 ms Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16 Intel(R) Core(TM) i9-9980HK CPU 2.40GHz sort merge join: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ sort merge join wholestage off 1453 1490 52 1.4 693.0 1.0X sort merge join wholestage on 1115 1170 43 1.9 531.6 1.3X Running benchmark: sort merge join with duplicates Running case: sort merge join with duplicates wholestage off Stopped after 2 iterations, 3236 ms Running case: sort merge join with duplicates wholestage on Stopped after 5 iterations, 6768 ms Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16 Intel(R) Core(TM) i9-9980HK CPU 2.40GHz sort merge join with duplicates: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------------ sort merge join with duplicates wholestage off 1609 1618 13 1.3 767.2 1.0X sort merge join with duplicates wholestage on 1330 1354 24 1.6 634.4 1.2X ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? * Added unit test in `WholeStageCodegenSuite.scala`. * Existing unit test in `OuterJoinSuite.scala`. Closes #34581 from c21/smj-codegen. Authored-by: Cheng Su <chen...@fb.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/execution/joins/SortMergeJoinExec.scala | 256 ++++++++++++++++++++- .../approved-plans-v1_4/q51.sf100/explain.txt | 4 +- .../approved-plans-v1_4/q51.sf100/simplified.txt | 5 +- .../approved-plans-v1_4/q51/explain.txt | 4 +- .../approved-plans-v1_4/q51/simplified.txt | 5 +- .../approved-plans-v1_4/q97.sf100/explain.txt | 4 +- .../approved-plans-v1_4/q97.sf100/simplified.txt | 5 +- .../approved-plans-v1_4/q97/explain.txt | 4 +- .../approved-plans-v1_4/q97/simplified.txt | 5 +- .../approved-plans-v2_7/q51a.sf100/explain.txt | 4 +- .../approved-plans-v2_7/q51a.sf100/simplified.txt | 5 +- .../approved-plans-v2_7/q51a/explain.txt | 4 +- .../approved-plans-v2_7/q51a/simplified.txt | 5 +- .../sql/execution/WholeStageCodegenSuite.scala | 82 ++++--- 14 files changed, 327 insertions(+), 65 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 66054bf..afed14a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -364,7 +364,8 @@ case class SortMergeJoinExec( } private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match { - case _: InnerLike | LeftOuter | LeftSemi 
| LeftAnti => ((left, leftKeys), (right, rightKeys)) + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | FullOuter => + ((left, leftKeys), (right, rightKeys)) case RightOuter => ((right, rightKeys), (left, leftKeys)) case x => throw new IllegalArgumentException( @@ -374,9 +375,10 @@ case class SortMergeJoinExec( private lazy val streamedOutput = streamedPlan.output private lazy val bufferedOutput = bufferedPlan.output + // TODO(SPARK-37316): Add code-gen for existence sort merge join. override def supportCodegen: Boolean = joinType match { - case _: InnerLike | LeftOuter | RightOuter | LeftSemi | LeftAnti => true - case _ => false + case _: ExistenceJoin => false + case _ => true } override def inputRDDs(): Seq[RDD[InternalRow]] = { @@ -644,6 +646,12 @@ case class SortMergeJoinExec( override def needCopyResult: Boolean = true override def doProduce(ctx: CodegenContext): String = { + // Specialize `doProduce` code for full outer join, because full outer join needs to + // buffer both sides of join. + if (joinType == FullOuter) { + return codegenFullOuter(ctx) + } + // Inline mutable state since not many join operations in a task val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];", forceInline = true) @@ -890,6 +898,248 @@ case class SortMergeJoinExec( """.stripMargin } + /** + * Generates the code for Full Outer join. + */ + private def codegenFullOuter(ctx: CodegenContext): String = { + // Inline mutable state since not many join operations in a task. + // Create class member for input iterator from both sides. + val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", + v => s"$v = inputs[0];", forceInline = true) + val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", + v => s"$v = inputs[1];", forceInline = true) + + // Create class member for next input row from both sides. 
+ val leftInputRow = ctx.addMutableState("InternalRow", "leftInputRow", forceInline = true) + val rightInputRow = ctx.addMutableState("InternalRow", "rightInputRow", forceInline = true) + + // Create variables for join keys from both sides. + val leftKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output) + val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") + val rightKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output) + val rightAnyNull = rightKeyVars.map(_.isNull).mkString(" || ") + val matchedKeyVars = copyKeys(ctx, leftKeyVars) + val leftMatchedKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output) + val rightMatchedKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output) + + // Create class member for next output row from both sides. + val leftOutputRow = ctx.addMutableState("InternalRow", "leftOutputRow", forceInline = true) + val rightOutputRow = ctx.addMutableState("InternalRow", "rightOutputRow", forceInline = true) + + // Create class member for buffers of rows with same join keys from both sides. 
+ val bufferClsName = "java.util.ArrayList<InternalRow>" + val leftBuffer = ctx.addMutableState(bufferClsName, "leftBuffer", + v => s"$v = new $bufferClsName();", forceInline = true) + val rightBuffer = ctx.addMutableState(bufferClsName, "rightBuffer", + v => s"$v = new $bufferClsName();", forceInline = true) + val matchedClsName = classOf[BitSet].getName + val leftMatched = ctx.addMutableState(matchedClsName, "leftMatched", + v => s"$v = new $matchedClsName(1);", forceInline = true) + val rightMatched = ctx.addMutableState(matchedClsName, "rightMatched", + v => s"$v = new $matchedClsName(1);", forceInline = true) + val leftIndex = ctx.freshName("leftIndex") + val rightIndex = ctx.freshName("rightIndex") + + // Generate code for join condition + val leftResultVars = genOneSideJoinVars( + ctx, leftOutputRow, left, setDefaultValue = true) + val rightResultVars = genOneSideJoinVars( + ctx, rightOutputRow, right, setDefaultValue = true) + val resultVars = leftResultVars ++ rightResultVars + val (_, conditionCheck, _) = + getJoinCondition(ctx, leftResultVars, left, right, Some(rightOutputRow)) + + // Generate code for result output in separate function, as we need to output result from + // multiple places in join code. + val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow") + ctx.addNewFunction(consumeFullOuterJoinRow, + s""" + |private void $consumeFullOuterJoinRow() throws java.io.IOException { + | ${metricTerm(ctx, "numOutputRows")}.add(1); + | ${consume(ctx, resultVars)} + |} + """.stripMargin) + + // Handle the case when input row has no match. 
+ val outputLeftNoMatch = + s""" + |$leftOutputRow = $leftInputRow; + |$rightOutputRow = null; + |$leftInputRow = null; + |$consumeFullOuterJoinRow(); + """.stripMargin + val outputRightNoMatch = + s""" + |$rightOutputRow = $rightInputRow; + |$leftOutputRow = null; + |$rightInputRow = null; + |$consumeFullOuterJoinRow(); + """.stripMargin + + // Generate a function to scan both sides to find rows with matched join keys. + // The matched rows from both sides are copied in buffers separately. This function assumes + // either non-empty `leftIter` and `rightIter`, or non-null `leftInputRow` and `rightInputRow`. + // + // The function has the following steps: + // - Step 1: Find the next `leftInputRow` and `rightInputRow` with non-null join keys. + // Output row with null join keys (`outputLeftNoMatch` and `outputRightNoMatch`). + // + // - Step 2: Compare and find next same join keys from between `leftInputRow` and + // `rightInputRow`. + // Output row with smaller join keys (`outputLeftNoMatch` and `outputRightNoMatch`). + // + // - Step 3: Buffer rows with same join keys from both sides into `leftBuffer` and + // `rightBuffer`. Reset bit sets for both buffers accordingly (`leftMatched` and + // `rightMatched`). 
+ val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows") + ctx.addNewFunction(findNextJoinRowsFuncName, + s""" + |private void $findNextJoinRowsFuncName( + | scala.collection.Iterator leftIter, + | scala.collection.Iterator rightIter) throws java.io.IOException { + | int comp = 0; + | $leftBuffer.clear(); + | $rightBuffer.clear(); + | + | if ($leftInputRow == null) { + | $leftInputRow = (InternalRow) leftIter.next(); + | } + | if ($rightInputRow == null) { + | $rightInputRow = (InternalRow) rightIter.next(); + | } + | + | ${leftKeyVars.map(_.code).mkString("\n")} + | if ($leftAnyNull) { + | // The left row join key is null, join it with null row + | $outputLeftNoMatch + | return; + | } + | + | ${rightKeyVars.map(_.code).mkString("\n")} + | if ($rightAnyNull) { + | // The right row join key is null, join it with null row + | $outputRightNoMatch + | return; + | } + | + | ${genComparison(ctx, leftKeyVars, rightKeyVars)} + | if (comp < 0) { + | // The left row join key is smaller, join it with null row + | $outputLeftNoMatch + | return; + | } else if (comp > 0) { + | // The right row join key is smaller, join it with null row + | $outputRightNoMatch + | return; + | } + | + | ${matchedKeyVars.map(_.code).mkString("\n")} + | $leftBuffer.add($leftInputRow.copy()); + | $rightBuffer.add($rightInputRow.copy()); + | $leftInputRow = null; + | $rightInputRow = null; + | + | // Buffer rows from both sides with same join key + | while (leftIter.hasNext()) { + | $leftInputRow = (InternalRow) leftIter.next(); + | ${leftMatchedKeyVars.map(_.code).mkString("\n")} + | ${genComparison(ctx, leftMatchedKeyVars, matchedKeyVars)} + | if (comp == 0) { + | + | $leftBuffer.add($leftInputRow.copy()); + | $leftInputRow = null; + | } else { + | break; + | } + | } + | while (rightIter.hasNext()) { + | $rightInputRow = (InternalRow) rightIter.next(); + | ${rightMatchedKeyVars.map(_.code).mkString("\n")} + | ${genComparison(ctx, rightMatchedKeyVars, matchedKeyVars)} + | if (comp == 0) { + 
| $rightBuffer.add($rightInputRow.copy()); + | $rightInputRow = null; + | } else { + | break; + | } + | } + | + | // Reset bit sets of buffers accordingly + | if ($leftBuffer.size() <= $leftMatched.capacity()) { + | $leftMatched.clearUntil($leftBuffer.size()); + | } else { + | $leftMatched = new $matchedClsName($leftBuffer.size()); + | } + | if ($rightBuffer.size() <= $rightMatched.capacity()) { + | $rightMatched.clearUntil($rightBuffer.size()); + | } else { + | $rightMatched = new $matchedClsName($rightBuffer.size()); + | } + |} + """.stripMargin) + + // Scan the left and right buffers to find all matched rows. + val matchRowsInBuffer = + s""" + |int $leftIndex; + |int $rightIndex; + | + |for ($leftIndex = 0; $leftIndex < $leftBuffer.size(); $leftIndex++) { + | $leftOutputRow = (InternalRow) $leftBuffer.get($leftIndex); + | for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) { + | $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex); + | $conditionCheck { + | $consumeFullOuterJoinRow(); + | $leftMatched.set($leftIndex); + | $rightMatched.set($rightIndex); + | } + | } + | + | if (!$leftMatched.get($leftIndex)) { + | + | $rightOutputRow = null; + | $consumeFullOuterJoinRow(); + | } + |} + | + |$leftOutputRow = null; + |for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) { + | if (!$rightMatched.get($rightIndex)) { + | // The right row has never matched any left row, join it with null row + | $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex); + | $consumeFullOuterJoinRow(); + | } + |} + """.stripMargin + + s""" + |while (($leftInputRow != null || $leftInput.hasNext()) && + | ($rightInputRow != null || $rightInput.hasNext())) { + | $findNextJoinRowsFuncName($leftInput, $rightInput); + | $matchRowsInBuffer + | if (shouldStop()) return; + |} + | + |// The right iterator has no more rows, join left row with null + |while ($leftInputRow != null || $leftInput.hasNext()) { + | if ($leftInputRow == null) { + | 
$leftInputRow = (InternalRow) $leftInput.next(); + | } + | $outputLeftNoMatch + | if (shouldStop()) return; + |} + | + |// The left iterator has no more rows, join right row with null + |while ($rightInputRow != null || $rightInput.hasNext()) { + | if ($rightInputRow == null) { + | $rightInputRow = (InternalRow) $rightInput.next(); + | } + | $outputRightNoMatch + | if (shouldStop()) return; + |} + """.stripMargin + } + override protected def withNewChildrenInternal( newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec = copy(left = newLeft, right = newRight) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt index 51b1ae5..cbb189e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/explain.txt @@ -5,7 +5,7 @@ TakeOrderedAndProject (37) +- * Sort (34) +- Exchange (33) +- * Project (32) - +- SortMergeJoin FullOuter (31) + +- * SortMergeJoin FullOuter (31) :- * Sort (15) : +- Exchange (14) : +- * Project (13) @@ -176,7 +176,7 @@ Arguments: hashpartitioning(item_sk#25, d_date#20, 5), ENSURE_REQUIREMENTS, [id= Input [3]: [item_sk#25, d_date#20, cume_sales#28] Arguments: [item_sk#25 ASC NULLS FIRST, d_date#20 ASC NULLS FIRST], false, 0 -(31) SortMergeJoin +(31) SortMergeJoin [codegen id : 13] Left keys [2]: [item_sk#11, d_date#6] Right keys [2]: [item_sk#25, d_date#20] Join condition: None diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt index 38d3f50..489aab1 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt +++ 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51.sf100/simplified.txt @@ -9,8 +9,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store Exchange [item_sk] #1 WholeStageCodegen (13) Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales] - InputAdapter - SortMergeJoin [item_sk,d_date,item_sk,d_date] + SortMergeJoin [item_sk,d_date,item_sk,d_date] + InputAdapter WholeStageCodegen (6) Sort [item_sk,d_date] InputAdapter @@ -45,6 +45,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] InputAdapter ReusedExchange [d_date_sk,d_date] #5 + InputAdapter WholeStageCodegen (12) Sort [item_sk,d_date] InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt index 51b1ae5..cbb189e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/explain.txt @@ -5,7 +5,7 @@ TakeOrderedAndProject (37) +- * Sort (34) +- Exchange (33) +- * Project (32) - +- SortMergeJoin FullOuter (31) + +- * SortMergeJoin FullOuter (31) :- * Sort (15) : +- Exchange (14) : +- * Project (13) @@ -176,7 +176,7 @@ Arguments: hashpartitioning(item_sk#25, d_date#20, 5), ENSURE_REQUIREMENTS, [id= Input [3]: [item_sk#25, d_date#20, cume_sales#28] Arguments: [item_sk#25 ASC NULLS FIRST, d_date#20 ASC NULLS FIRST], false, 0 -(31) SortMergeJoin +(31) SortMergeJoin [codegen id : 13] Left keys [2]: [item_sk#11, d_date#6] Right keys [2]: [item_sk#25, d_date#20] Join condition: None diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt index 38d3f50..489aab1 100644 --- 
a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q51/simplified.txt @@ -9,8 +9,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store Exchange [item_sk] #1 WholeStageCodegen (13) Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales] - InputAdapter - SortMergeJoin [item_sk,d_date,item_sk,d_date] + SortMergeJoin [item_sk,d_date,item_sk,d_date] + InputAdapter WholeStageCodegen (6) Sort [item_sk,d_date] InputAdapter @@ -45,6 +45,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq] InputAdapter ReusedExchange [d_date_sk,d_date] #5 + InputAdapter WholeStageCodegen (12) Sort [item_sk,d_date] InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt index e9e97e9..e47aaf2 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/explain.txt @@ -3,7 +3,7 @@ +- Exchange (22) +- * HashAggregate (21) +- * Project (20) - +- SortMergeJoin FullOuter (19) + +- * SortMergeJoin FullOuter (19) :- * Sort (9) : +- * HashAggregate (8) : +- Exchange (7) @@ -112,7 +112,7 @@ Results [2]: [cs_bill_customer_sk#9 AS customer_sk#14, cs_item_sk#10 AS item_sk# Input [2]: [customer_sk#14, item_sk#15] Arguments: [customer_sk#14 ASC NULLS FIRST, item_sk#15 ASC NULLS FIRST], false, 0 -(19) SortMergeJoin +(19) SortMergeJoin [codegen id : 7] Left keys [2]: [customer_sk#7, item_sk#8] Right keys [2]: [customer_sk#14, item_sk#15] Join condition: None diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt 
b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt index 227b3c6..99c8a1d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97.sf100/simplified.txt @@ -5,8 +5,8 @@ WholeStageCodegen (8) WholeStageCodegen (7) HashAggregate [customer_sk,customer_sk] [sum,sum,sum,sum,sum,sum] Project [customer_sk,customer_sk] - InputAdapter - SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk] + SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk] + InputAdapter WholeStageCodegen (3) Sort [customer_sk,item_sk] HashAggregate [ss_customer_sk,ss_item_sk] [customer_sk,item_sk] @@ -29,6 +29,7 @@ WholeStageCodegen (8) Scan parquet default.date_dim [d_date_sk,d_month_seq] InputAdapter ReusedExchange [d_date_sk] #3 + InputAdapter WholeStageCodegen (6) Sort [customer_sk,item_sk] HashAggregate [cs_bill_customer_sk,cs_item_sk] [customer_sk,item_sk] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt index e9e97e9..e47aaf2 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/explain.txt @@ -3,7 +3,7 @@ +- Exchange (22) +- * HashAggregate (21) +- * Project (20) - +- SortMergeJoin FullOuter (19) + +- * SortMergeJoin FullOuter (19) :- * Sort (9) : +- * HashAggregate (8) : +- Exchange (7) @@ -112,7 +112,7 @@ Results [2]: [cs_bill_customer_sk#9 AS customer_sk#14, cs_item_sk#10 AS item_sk# Input [2]: [customer_sk#14, item_sk#15] Arguments: [customer_sk#14 ASC NULLS FIRST, item_sk#15 ASC NULLS FIRST], false, 0 -(19) SortMergeJoin +(19) SortMergeJoin [codegen id : 7] Left keys [2]: [customer_sk#7, item_sk#8] Right keys [2]: [customer_sk#14, item_sk#15] Join condition: 
None diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt index 227b3c6..99c8a1d 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q97/simplified.txt @@ -5,8 +5,8 @@ WholeStageCodegen (8) WholeStageCodegen (7) HashAggregate [customer_sk,customer_sk] [sum,sum,sum,sum,sum,sum] Project [customer_sk,customer_sk] - InputAdapter - SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk] + SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk] + InputAdapter WholeStageCodegen (3) Sort [customer_sk,item_sk] HashAggregate [ss_customer_sk,ss_item_sk] [customer_sk,item_sk] @@ -29,6 +29,7 @@ WholeStageCodegen (8) Scan parquet default.date_dim [d_date_sk,d_month_seq] InputAdapter ReusedExchange [d_date_sk] #3 + InputAdapter WholeStageCodegen (6) Sort [customer_sk,item_sk] HashAggregate [cs_bill_customer_sk,cs_item_sk] [customer_sk,item_sk] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt index 740ea0f..64111ee 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/explain.txt @@ -10,7 +10,7 @@ TakeOrderedAndProject (70) : +- Exchange (58) : +- * Project (57) : +- * Filter (56) - : +- SortMergeJoin FullOuter (55) + : +- * SortMergeJoin FullOuter (55) : :- * Sort (27) : : +- Exchange (26) : : +- * HashAggregate (25) @@ -317,7 +317,7 @@ Arguments: hashpartitioning(item_sk#38, d_date#33, 5), ENSURE_REQUIREMENTS, [id= Input [3]: [item_sk#38, d_date#33, cume_sales#54] Arguments: [item_sk#38 ASC NULLS FIRST, d_date#33 ASC NULLS FIRST], false, 0 
-(55) SortMergeJoin +(55) SortMergeJoin [codegen id : 29] Left keys [2]: [item_sk#11, d_date#6] Right keys [2]: [item_sk#38, d_date#33] Join condition: None diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt index ed52e97..1a89b7c 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a.sf100/simplified.txt @@ -14,8 +14,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store WholeStageCodegen (29) Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales] Filter [item_sk,item_sk] - InputAdapter - SortMergeJoin [item_sk,d_date,item_sk,d_date] + SortMergeJoin [item_sk,d_date,item_sk,d_date] + InputAdapter WholeStageCodegen (14) Sort [item_sk,d_date] InputAdapter @@ -73,6 +73,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store Sort [ws_item_sk,d_date] InputAdapter ReusedExchange [item_sk,d_date,sumws,ws_item_sk] #4 + InputAdapter WholeStageCodegen (28) Sort [item_sk,d_date] InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt index cf86cd6..9edb377 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/explain.txt @@ -10,7 +10,7 @@ TakeOrderedAndProject (67) : +- Exchange (54) : +- * Project (53) : +- * Filter (52) - : +- SortMergeJoin FullOuter (51) + : +- * SortMergeJoin FullOuter (51) : :- * Sort (25) : : +- Exchange (24) : : +- * HashAggregate (23) @@ -298,7 +298,7 @@ Arguments: hashpartitioning(item_sk#38, d_date#33, 5), ENSURE_REQUIREMENTS, [id= Input 
[3]: [item_sk#38, d_date#33, cume_sales#54] Arguments: [item_sk#38 ASC NULLS FIRST, d_date#33 ASC NULLS FIRST], false, 0 -(51) SortMergeJoin +(51) SortMergeJoin [codegen id : 25] Left keys [2]: [item_sk#11, d_date#6] Right keys [2]: [item_sk#38, d_date#33] Join condition: None diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt index a99caa4..d6612db 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q51a/simplified.txt @@ -14,8 +14,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store WholeStageCodegen (25) Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales] Filter [item_sk,item_sk] - InputAdapter - SortMergeJoin [item_sk,d_date,item_sk,d_date] + SortMergeJoin [item_sk,d_date,item_sk,d_date] + InputAdapter WholeStageCodegen (12) Sort [item_sk,d_date] InputAdapter @@ -67,6 +67,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store Sort [ws_item_sk,d_date] InputAdapter ReusedExchange [item_sk,d_date,sumws,ws_item_sk] #4 + InputAdapter WholeStageCodegen (24) Sort [item_sk,d_date] InputAdapter diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index f483971..00ea371 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -171,48 +171,54 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4))) } - test("Full Outer ShuffledHashJoin should be included in WholeStageCodegen") { + 
test("Full Outer ShuffledHashJoin and SortMergeJoin should be included in WholeStageCodegen") { val df1 = spark.range(5).select($"id".as("k1")) val df2 = spark.range(10).select($"id".as("k2")) val df3 = spark.range(3).select($"id".as("k3")) - // test one join with unique key from build side - val joinUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", "full_outer") - assert(joinUniqueDF.queryExecution.executedPlan.collect { - case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true - }.size === 1) - checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4), - Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9))) - assert(joinUniqueDF.count() === 10) - - // test one join with non-unique key from build side - val joinNonUniqueDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2" % 3, "full_outer") - assert(joinNonUniqueDF.queryExecution.executedPlan.collect { - case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true - }.size === 1) - checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1), - Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null))) - - // test one join with non-equi condition - val joinWithNonEquiDF = df1.join(df2.hint("SHUFFLE_HASH"), - $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer") - assert(joinWithNonEquiDF.queryExecution.executedPlan.collect { - case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true - }.size === 1) - checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1), - Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4), - Row(null, 5))) - - // test two joins - val twoJoinsDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", "full_outer") - .join(df3.hint("SHUFFLE_HASH"), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "full_outer") - assert(twoJoinsDF.queryExecution.executedPlan.collect { - case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true 
- }.size === 2) - checkAnswer(twoJoinsDF, - Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null), - Row(null, 5, null), Row(null, 6, null), Row(null, 7, null), Row(null, 8, null), - Row(null, 9, null), Row(null, null, 1))) + Seq("SHUFFLE_HASH", "SHUFFLE_MERGE").foreach { hint => + // test one join with unique key from build side + val joinUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer") + assert(joinUniqueDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true + case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true + }.size === 1) + checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4), + Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9))) + assert(joinUniqueDF.count() === 10) + + // test one join with non-unique key from build side + val joinNonUniqueDF = df1.join(df2.hint(hint), $"k1" === $"k2" % 3, "full_outer") + assert(joinNonUniqueDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true + case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true + }.size === 1) + checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1), + Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null))) + + // test one join with non-equi condition + val joinWithNonEquiDF = df1.join(df2.hint(hint), + $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer") + assert(joinWithNonEquiDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true + case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true + }.size === 1) + checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1), + Row(1, 7), Row(2, 2), Row(2, 8), Row(3, 
null), Row(4, null), Row(null, 3), Row(null, 4), + Row(null, 5))) + + // test two joins + val twoJoinsDF = df1.join(df2.hint(hint), $"k1" === $"k2", "full_outer") + .join(df3.hint(hint), $"k1" === $"k3" && $"k1" + $"k3" =!= 2, "full_outer") + assert(twoJoinsDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true + case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true + }.size === 2) + checkAnswer(twoJoinsDF, + Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null), + Row(null, 5, null), Row(null, 6, null), Row(null, 7, null), Row(null, 8, null), + Row(null, 9, null), Row(null, null, 1))) + } } test("Left/Right Outer SortMergeJoin should be included in WholeStageCodegen") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org