Repository: spark
Updated Branches:
  refs/heads/master 5bd5e1b9c -> eab39f79e


[SPARK-25755][SQL][TEST] Add non-codegen unit tests for BroadcastHashJoinExec

## What changes were proposed in this pull request?

Currently, the BroadcastHashJoinExec physical plan supports both codegen and 
non-codegen execution, but the unit tests in InnerJoinSuite, OuterJoinSuite, 
and ExistenceJoinSuite exercise only the codegen path; the non-codegen path is 
untested. This PR adds test coverage for that path.
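
In outline, the refactoring replaces the per-suite codegen loop with a shared 
`testWithWholeStageCodegenOnAndOff` helper added to `SQLTestUtils`. Below is a 
schematic sketch of the pattern; the test name and body are placeholders, not 
code from the diff:

```scala
// Before: each suite hand-rolled the on/off loop around its test body.
Seq(true, false).foreach { codegen =>
  test(s"some test: codegen = $codegen") {
    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString) {
      // ... assertions ...
    }
  }
}

// After: the shared helper registers both variants and sets the conf itself.
testWithWholeStageCodegenOnAndOff("some test") { codegenEnabled =>
  // codegenEnabled is the string "true" or "false", for bodies that branch on it.
  // ... assertions ...
}
```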

## How was this patch tested?

Added new unit tests.

Closes #22755 from heary-cao/AddTestToBroadcastHashJoinExec.

Authored-by: caoxuewen <cao.xue...@zte.com.cn>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eab39f79
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eab39f79
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eab39f79

Branch: refs/heads/master
Commit: eab39f79e4c2fb51266ff5844114ee56b8ec2d91
Parents: 5bd5e1b
Author: caoxuewen <cao.xue...@zte.com.cn>
Authored: Tue Oct 30 20:13:18 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Oct 30 20:13:18 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/DataFrameAggregateSuite.scala     | 30 ++++----
 .../spark/sql/DataFrameFunctionsSuite.scala     | 15 ++--
 .../apache/spark/sql/DataFrameRangeSuite.scala  | 76 +++++++++-----------
 .../columnar/InMemoryColumnarQuerySuite.scala   | 39 +++++-----
 .../execution/joins/ExistenceJoinSuite.scala    |  2 +-
 .../sql/execution/joins/InnerJoinSuite.scala    |  6 +-
 .../sql/execution/joins/OuterJoinSuite.scala    |  2 +-
 .../apache/spark/sql/test/SQLTestUtils.scala    | 15 ++++
 8 files changed, 90 insertions(+), 95 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eab39f79/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0106c4..d9ba6e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -669,23 +669,19 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  Seq(true, false).foreach { codegen =>
-    test("SPARK-22951: dropDuplicates on empty dataFrames should produce 
correct aggregate " +
-      s"results when codegen is enabled: $codegen") {
-      withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) {
-        // explicit global aggregations
-        val emptyAgg = Map.empty[String, String]
-        checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row()))
-        checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row()))
-        checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0)))
-        checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row()))
-        checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row()))
-        checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0)))
-
-        // global aggregation is converted to grouping aggregation:
-        assert(spark.emptyDataFrame.dropDuplicates().count() == 0)
-      }
-    }
+  testWithWholeStageCodegenOnAndOff("SPARK-22951: dropDuplicates on empty 
dataFrames " +
+    "should produce correct aggregate") { _ =>
+    // explicit global aggregations
+    val emptyAgg = Map.empty[String, String]
+    checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row()))
+    checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row()))
+    checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0)))
+    checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row()))
+    checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row()))
+    checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0)))
+
+    // global aggregation is converted to grouping aggregation:
+    assert(spark.emptyDataFrame.dropDuplicates().count() == 0)
   }
 
   test("SPARK-21896: Window functions inside aggregate functions") {

http://git-wip-us.apache.org/repos/asf/spark/blob/eab39f79/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 60ebc5e..666ba35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -458,15 +458,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8)
   }
 
-  test("SPARK-24633: arrays_zip splits input processing correctly") {
-    Seq("true", "false").foreach { wholestageCodegenEnabled =>
-      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholestageCodegenEnabled) {
-        val df = spark.range(1)
-        val exprs = (0 to 5).map(x => array($"id" + lit(x)))
-        checkAnswer(df.select(arrays_zip(exprs: _*)),
-          Row(Seq(Row(0, 1, 2, 3, 4, 5))))
-      }
-    }
+  testWithWholeStageCodegenOnAndOff("SPARK-24633: arrays_zip splits input " +
+    "processing correctly") { _ =>
+    val df = spark.range(1)
+    val exprs = (0 to 5).map(x => array($"id" + lit(x)))
+    checkAnswer(df.select(arrays_zip(exprs: _*)),
+      Row(Seq(Row(0, 1, 2, 3, 4, 5))))
   }
 
   def testSizeOfMap(sizeOfNull: Any): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/eab39f79/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
index b0b4664..8cc7020 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 
 class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventually {
-  import testImplicits._
 
   test("SPARK-7150 range api") {
     // numSlice is greater than length
@@ -107,7 +106,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
     assert(res17.collect === (1 to 10).map(i => Row(i)).toArray)
   }
 
-  test("Range with randomized parameters") {
+  testWithWholeStageCodegenOnAndOff("Range with randomized parameters") { 
codegenEnabled =>
     val MAX_NUM_STEPS = 10L * 1000
 
     val seed = System.currentTimeMillis()
@@ -133,25 +132,21 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
       val expCount = (start until end by step).size
       val expSum = (start until end by step).sum
 
-      for (codegen <- List(false, true)) {
-        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
-          val res = spark.range(start, end, step, partitions).toDF("id").
-            agg(count("id"), sum("id")).collect()
-
-          withClue(s"seed = $seed start = $start end = $end step = $step 
partitions = " +
-              s"$partitions codegen = $codegen") {
-            assert(!res.isEmpty)
-            assert(res.head.getLong(0) == expCount)
-            if (expCount > 0) {
-              assert(res.head.getLong(1) == expSum)
-            }
-          }
+      val res = spark.range(start, end, step, partitions).toDF("id").
+        agg(count("id"), sum("id")).collect()
+
+      withClue(s"seed = $seed start = $start end = $end step = $step 
partitions = " +
+        s"$partitions codegen = $codegenEnabled") {
+        assert(!res.isEmpty)
+        assert(res.head.getLong(0) == expCount)
+        if (expCount > 0) {
+          assert(res.head.getLong(1) == expSum)
         }
       }
     }
   }
 
-  test("Cancelling stage in a query with Range.") {
+  testWithWholeStageCodegenOnAndOff("Cancelling stage in a query with Range.") 
{ _ =>
     val listener = new SparkListener {
       override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
         sparkContext.cancelStage(taskStart.stageId)
@@ -159,27 +154,25 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
     }
 
     sparkContext.addSparkListener(listener)
-    for (codegen <- Seq(true, false)) {
-      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
-        val ex = intercept[SparkException] {
-          spark.range(0, 100000000000L, 1, 1)
-            .toDF("id").agg(sum("id")).collect()
-        }
-        ex.getCause() match {
-          case null =>
-            assert(ex.getMessage().contains("cancelled"))
-          case cause: SparkException =>
-            assert(cause.getMessage().contains("cancelled"))
-          case cause: Throwable =>
-            fail("Expected the cause to be SparkException, got " + 
cause.toString() + " instead.")
-        }
-      }
-      // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages
-      sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis)
-      eventually(timeout(20.seconds)) {
-        assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
-      }
+    val ex = intercept[SparkException] {
+      spark.range(0, 100000000000L, 1, 1)
+        .toDF("id").agg(sum("id")).collect()
+    }
+    ex.getCause() match {
+      case null =>
+        assert(ex.getMessage().contains("cancelled"))
+      case cause: SparkException =>
+        assert(cause.getMessage().contains("cancelled"))
+      case cause: Throwable =>
+        fail("Expected the cause to be SparkException, got " + 
cause.toString() + " instead.")
     }
+
+    // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages
+    sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis)
+    eventually(timeout(20.seconds)) {
+      assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
+    }
+
     sparkContext.removeSparkListener(listener)
   }
 
@@ -189,14 +182,11 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
     }
   }
 
-  test("SPARK-21041 SparkSession.range()'s behavior is inconsistent with 
SparkContext.range()") {
+  testWithWholeStageCodegenOnAndOff("SPARK-21041 SparkSession.range()'s 
behavior is " +
+    "inconsistent with SparkContext.range()") { _ =>
     val start = java.lang.Long.MAX_VALUE - 3
     val end = java.lang.Long.MIN_VALUE + 2
-    Seq("false", "true").foreach { value =>
-      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value) {
-        assert(spark.range(start, end, 1).collect.length == 0)
-        assert(spark.range(start, start, 1).collect.length == 0)
-      }
-    }
+    assert(spark.range(start, end, 1).collect.length == 0)
+    assert(spark.range(start, start, 1).collect.length == 0)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/eab39f79/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index b1b23e4..e1567d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -463,29 +463,26 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     assert(tableScanExec.partitionFilters.isEmpty)
   }
 
-  test("SPARK-22348: table cache should do partition batch pruning") {
-    Seq("true", "false").foreach { enabled =>
-      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) {
-        val df1 = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y")
-        df1.unpersist()
-        df1.cache()
-
-        // Push predicate to the cached table.
-        val df2 = df1.where("y = 3")
-
-        val planBeforeFilter = df2.queryExecution.executedPlan.collect {
-          case f: FilterExec => f.child
-        }
-        assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec])
+  testWithWholeStageCodegenOnAndOff("SPARK-22348: table cache " +
+    "should do partition batch pruning") { codegenEnabled =>
+    val df1 = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y")
+    df1.unpersist()
+    df1.cache()
 
-        val execPlan = if (enabled == "true") {
-          WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0)
-        } else {
-          planBeforeFilter.head
-        }
-        assert(execPlan.executeCollectPublic().length == 0)
-      }
+    // Push predicate to the cached table.
+    val df2 = df1.where("y = 3")
+
+    val planBeforeFilter = df2.queryExecution.executedPlan.collect {
+      case f: FilterExec => f.child
+    }
+    assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec])
+
+    val execPlan = if (codegenEnabled == "true") {
+      WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0)
+    } else {
+      planBeforeFilter.head
     }
+    assert(execPlan.executeCollectPublic().length == 0)
   }
 
   test("SPARK-25727 - otherCopyArgs in InMemoryRelation does not include 
outputOrdering") {

http://git-wip-us.apache.org/repos/asf/spark/blob/eab39f79/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 3837716..22279a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -120,7 +120,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
       }
     }
 
-    test(s"$testName using BroadcastHashJoin") {
+    testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { 
_ =>
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: 
SparkPlan) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/eab39f79/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 4408ece..f5edd6b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -127,7 +127,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin)
     }
 
-    test(s"$testName using BroadcastHashJoin (build=left)") {
+    testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin 
(build=left)") { _ =>
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: 
SparkPlan) =>
@@ -139,7 +139,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       }
     }
 
-    test(s"$testName using BroadcastHashJoin (build=right)") {
+    testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin 
(build=right)") { _ =>
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: 
SparkPlan) =>
@@ -175,7 +175,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       }
     }
 
-    test(s"$testName using SortMergeJoin") {
+    testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ =>
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/eab39f79/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 001feb0..513248d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -93,7 +93,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     if (joinType != FullOuter) {
-      test(s"$testName using BroadcastHashJoin") {
+      testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") 
{ _ =>
         val buildSide = joinType match {
           case LeftOuter => BuildRight
           case RightOuter => BuildLeft

http://git-wip-us.apache.org/repos/asf/spark/blob/eab39f79/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 6b03d1e..2341949 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.FilterExec
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.UninterruptibleThread
 import org.apache.spark.util.Utils
 
@@ -66,6 +67,20 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with
   }
 
   /**
+   * A helper function for turning off/on codegen.
+   */
+  protected def testWithWholeStageCodegenOnAndOff(testName: String)(f: String => Unit): Unit = {
+    Seq("false", "true").foreach { codegenEnabled =>
+      val isTurnOn = if (codegenEnabled == "true") "on" else "off"
+      test(s"$testName (whole-stage-codegen ${isTurnOn})") {
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled) {
+          f(codegenEnabled)
+        }
+      }
+    }
+  }
+
+  /**
    * Materialize the test data immediately after the `SQLContext` is set up.
    * This is necessary if the data is accessed by name but not through direct reference.
    */
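
For reference, a minimal usage sketch of the new helper (the test name and assertion here are hypothetical). One call registers two tests, suffixed "(whole-stage-codegen off)" and "(whole-stage-codegen on)", each running the body under the corresponding value of SQLConf.WHOLESTAGE_CODEGEN_ENABLED:

```scala
testWithWholeStageCodegenOnAndOff("range count") { codegenEnabled =>
  // codegenEnabled is the string "false" or "true", matching the value
  // withSQLConf applied to SQLConf.WHOLESTAGE_CODEGEN_ENABLED for this variant.
  assert(spark.range(10).count() == 10)
}
```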

