This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8575d09f80b0 [SPARK-53522][SDP][TEST] Simplify PipelineTest
8575d09f80b0 is described below

commit 8575d09f80b0040a6e92cba1e26f70aff90a2c4a
Author: Wenchen Fan <[email protected]>
AuthorDate: Mon Sep 8 10:26:26 2025 -0700

    [SPARK-53522][SDP][TEST] Simplify PipelineTest
    
    ### What changes were proposed in this pull request?
    
    This PR simplifies `PipelineTest` by extending `QueryTest`, as many util 
functions are already defined there.
    
    ### Why are the changes needed?
    
    Code cleanup.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #52266 from cloud-fan/sdp_test.
    
    Authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../scala/org/apache/spark/SparkFunSuite.scala     |   7 ++
 .../pipelines/HiveMaterializeTablesSuite.scala     |   3 -
 .../pipelines/graph/MaterializeTablesSuite.scala   |   3 +
 .../spark/sql/pipelines/utils/PipelineTest.scala   | 106 +++------------------
 4 files changed, 25 insertions(+), 94 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala 
b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index cd421ba20bd7..3b0ec30ee86c 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -242,6 +242,13 @@ abstract class SparkFunSuite
     }
   }
 
+  protected def namedGridTest[A](testNamePrefix: String, testTags: 
Tag*)(params: Map[String, A])(
+    testFun: A => Unit): Unit = {
+    for (param <- params) {
+      test(testNamePrefix + s" ${param._1}", testTags: _*)(testFun(param._2))
+    }
+  }
+
   /**
    * Creates a temporary directory, which is then passed to `f` and will be 
deleted after `f`
    * returns.
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/pipelines/HiveMaterializeTablesSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/pipelines/HiveMaterializeTablesSuite.scala
index f57125438ccf..73dbc98d57c9 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/pipelines/HiveMaterializeTablesSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/pipelines/HiveMaterializeTablesSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.hive.pipelines
 
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.pipelines.graph.MaterializeTablesSuite
 
@@ -37,6 +36,4 @@ class HiveMaterializeTablesSuite extends 
MaterializeTablesSuite with TestHiveSin
       super.afterEach()
     }
   }
-
-  override protected implicit def sqlContext: SQLContext = spark.sqlContext
 }
diff --git 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
index 72e292ec5070..37c32a349866 100644
--- 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
+++ 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
@@ -262,6 +262,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
   test("invalid schema merge") {
     val session = spark
     import session.implicits._
+    implicit def sqlContext: org.apache.spark.sql.classic.SQLContext = 
spark.sqlContext
 
     val streamInts = MemoryStream[Int]
     streamInts.addData(1, 2)
@@ -329,6 +330,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
   test("specified schema incompatible with existing table") {
     val session = spark
     import session.implicits._
+    implicit def sqlContext: org.apache.spark.sql.classic.SQLContext = 
spark.sqlContext
 
     sql(s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t6(x 
BOOLEAN)")
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
@@ -627,6 +629,7 @@ abstract class MaterializeTablesSuite extends 
BaseCoreExecutionTest {
     ) {
       val session = spark
       import session.implicits._
+      implicit def sqlContext: org.apache.spark.sql.classic.SQLContext = 
spark.sqlContext
 
       val streamInts = MemoryStream[Int]
       streamInts.addData(1 until 5: _*)
diff --git 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
index 8db5c0c626b3..54b324c182f2 100644
--- 
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
+++ 
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
@@ -25,37 +25,28 @@ import scala.util.{Failure, Try}
 import scala.util.control.NonFatal
 
 import org.scalactic.source
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Tag}
+import org.scalatest.Tag
 import org.scalatest.concurrent.Eventually
-import org.scalatest.matchers.should.Matchers
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext, TypedColumn}
+import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, TypedColumn}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.pipelines.graph.{DataflowGraph, 
PipelineUpdateContextImpl, SqlGraphRegistrationContext}
 import org.apache.spark.sql.pipelines.utils.PipelineTest.{cleanupMetastore, 
createTempDir}
+import org.apache.spark.sql.test.SQLTestUtils
 
 abstract class PipelineTest
-    extends SparkFunSuite
-    with BeforeAndAfterAll
-    with BeforeAndAfterEach
-    with Matchers
-    with SparkErrorTestMixin
-    with TargetCatalogAndDatabaseMixin
-    with Logging
-    with Eventually {
+  extends QueryTest
+  with SQLTestUtils
+  with SparkErrorTestMixin
+  with TargetCatalogAndDatabaseMixin
+  with Logging
+  with Eventually {
 
   final protected val storageRoot = createTempDir()
 
-  protected def spark: SparkSession
-
-  protected implicit def sqlContext: SQLContext
-
-  def sql(text: String): DataFrame = spark.sql(text)
-
   protected def startPipelineAndWaitForCompletion(unresolvedDataflowGraph: 
DataflowGraph): Unit = {
     val updateContext = new PipelineUpdateContextImpl(
       unresolvedDataflowGraph, eventCallback = _ => ())
@@ -187,13 +178,6 @@ abstract class PipelineTest
     gridTest(testNamePrefix, paramName, testTags: _*)(Seq(true, 
false))(testFun)
   }
 
-  protected def namedGridTest[A](testNamePrefix: String, testTags: 
Tag*)(params: Map[String, A])(
-      testFun: A => Unit): Unit = {
-    for (param <- params) {
-      test(testNamePrefix + s" (${param._1})", testTags: _*)(testFun(param._2))
-    }
-  }
-
   protected def namedGridIgnore[A](testNamePrefix: String, testTags: 
Tag*)(params: Map[String, A])(
       testFun: A => Unit): Unit = {
     for (param <- params) {
@@ -221,14 +205,15 @@ abstract class PipelineTest
       df: => DataFrame,
       expectedAnswer: Seq[Row],
       checkPlan: Option[SparkPlan => Unit]): Unit = {
-    QueryTest.checkAnswer(df, expectedAnswer)
+    super.checkAnswer(df, expectedAnswer)
 
     // To help with test development, you can dump the plan to the log by 
passing
     // `--test_env=DUMP_PLAN=true` to `bazel test`.
+    val classicDf = df.asInstanceOf[org.apache.spark.sql.classic.DataFrame]
     if (Option(System.getenv("DUMP_PLAN")).exists(s => 
java.lang.Boolean.valueOf(s))) {
-      log.info(s"Spark plan:\n${df.queryExecution.executedPlan}")
+      log.info(s"Spark plan:\n${classicDf.queryExecution.executedPlan}")
     }
-    checkPlan.foreach(_.apply(df.queryExecution.executedPlan))
+    checkPlan.foreach(_.apply(classicDf.queryExecution.executedPlan))
   }
 
   /**
@@ -237,76 +222,15 @@ abstract class PipelineTest
    * @param df the `DataFrame` to be executed
    * @param expectedAnswer the expected result in a `Seq` of `Row`s.
    */
-  protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit 
= {
+  override protected def checkAnswer(df: => DataFrame, expectedAnswer: 
Seq[Row]): Unit = {
     checkAnswerAndPlan(df, expectedAnswer, None)
   }
 
-  protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
-    checkAnswer(df, Seq(expectedAnswer))
-  }
-
   case class ValidationArgs(
       ignoreFieldOrder: Boolean = false,
       ignoreFieldCase: Boolean = false
   )
 
-  /**
-   * Evaluates a dataset to make sure that the result of calling collect 
matches the given
-   * expected answer.
-   */
-  protected def checkDataset[T](ds: => Dataset[T], expectedAnswer: T*): Unit = 
{
-    val result = getResult(ds)
-
-    if (!QueryTest.compare(result.toSeq, expectedAnswer)) {
-      fail(s"""
-              |Decoded objects do not match expected objects:
-              |expected: $expectedAnswer
-              |actual:   ${result.toSeq}
-         """.stripMargin)
-    }
-  }
-
-  /**
-   * Evaluates a dataset to make sure that the result of calling collect 
matches the given
-   * expected answer, after sort.
-   */
-  protected def checkDatasetUnorderly[T: Ordering](result: Array[T], 
expectedAnswer: T*): Unit = {
-    if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) {
-      fail(s"""
-              |Decoded objects do not match expected objects:
-              |expected: $expectedAnswer
-              |actual:   ${result.toSeq}
-         """.stripMargin)
-    }
-  }
-
-  protected def checkDatasetUnorderly[T: Ordering](ds: => Dataset[T], 
expectedAnswer: T*): Unit = {
-    val result = getResult(ds)
-    if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) {
-      fail(s"""
-              |Decoded objects do not match expected objects:
-              |expected: $expectedAnswer
-              |actual:   ${result.toSeq}
-         """.stripMargin)
-    }
-  }
-
-  private def getResult[T](ds: => Dataset[T]): Array[T] = {
-    ds
-
-    try ds.collect()
-    catch {
-      case NonFatal(e) =>
-        fail(
-          s"""
-             |Exception collecting dataset as objects
-             |${ds.queryExecution}
-           """.stripMargin,
-          e
-        )
-    }
-  }
-
   /** Holds a parsed version along with the original json of a test. */
   private case class TestSequence(json: Seq[String], rows: Seq[Row]) {
     require(json.size == rows.size)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to