This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 02d1f09df0d [SPARK-44824][CONNECT][TESTS][3.5] Reset `ammoniteOut` in the `afterEach` method of `ReplE2ESuite`
02d1f09df0d is described below

commit 02d1f09df0da202e3996cdcfbca44525862528b9
Author: yangjie01 <yangji...@baidu.com>
AuthorDate: Wed Aug 16 21:03:01 2023 +0200

    [SPARK-44824][CONNECT][TESTS][3.5] Reset `ammoniteOut` in the `afterEach` method of `ReplE2ESuite`
    
    ### What changes were proposed in this pull request?
    This PR adds an `ammoniteOut.reset()` call to the `afterEach` method of `ReplE2ESuite` so that the `output` used for assertions in each test case contains only the output of the current case, rather than the accumulated output of all cases.
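    
    A minimal sketch of the resulting hook, assuming `ammoniteOut` is the `java.io.ByteArrayOutputStream` that captures the REPL output (as the `reset()` call in the patch implies) and using a plain ScalaTest `AnyFunSuite` as a hypothetical stand-in for the suite's actual base classes:
    
    ```
    import java.io.ByteArrayOutputStream
    
    import org.scalatest.BeforeAndAfterEach
    import org.scalatest.funsuite.AnyFunSuite
    
    class OutputIsolationSuite extends AnyFunSuite with BeforeAndAfterEach {
      // A single buffer captures the REPL output for every test in the suite.
      private val ammoniteOut = new ByteArrayOutputStream()
    
      override def afterEach(): Unit = {
        // Discard whatever the finished test printed, so the next test's
        // assertions only see output produced by that test. The actual patch
        // additionally guards against `ammoniteOut` being null.
        ammoniteOut.reset()
      }
    }
    ```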
    
    ### Why are the changes needed?
    The current `ammoniteOut` accumulates the output of every test that has already run, with no isolation between cases. This can lead to unexpected assertion results.
    For example, adding `assertContains("""String = "[MyTestClass(1), MyTestClass(3)]"""", output)` to the following test case would still pass, because that string was printed to `ammoniteOut` by a previous test case.
    
    
https://github.com/apache/spark/blob/2be20e54a2222f6cdf64e8486d1910133b43665f/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala#L283-L290
    
    Hence, we need to clear the content of `ammoniteOut` after each test to achieve isolation between test cases.
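    
    The failure mode can be shown with a shared `java.io.ByteArrayOutputStream` on its own; this is a standalone sketch where the buffer and the captured strings are illustrative stand-ins for the suite's `ammoniteOut` and REPL output:
    
    ```
    import java.io.ByteArrayOutputStream
    
    val ammoniteOut = new ByteArrayOutputStream()
    
    // "Test case 1" prints its result into the shared buffer.
    ammoniteOut.write("""res53: String = "[MyTestClass(1), MyTestClass(3)]"""".getBytes("UTF-8"))
    
    // "Test case 2" builds its `output` from the same buffer without any reset,
    // so an assertion on test case 1's result string still (wrongly) passes here.
    assert(ammoniteOut.toString("UTF-8").contains("""String = "[MyTestClass(1), MyTestClass(3)]""""))
    
    // With ammoniteOut.reset() in afterEach, each case starts from an empty
    // buffer and the same assertion fails, as expected.
    ammoniteOut.reset()
    assert(!ammoniteOut.toString("UTF-8").contains("""String = "[MyTestClass(1), MyTestClass(3)]""""))
    ```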
    
    ### Does this PR introduce _any_ user-facing change?
    No, this is test-only.
    
    ### How was this patch tested?
    - Pass GitHub Actions
    - Manual check
    
    Print the `output` after `val output = runCommandsInShell(input)` in the case `streaming works with REPL generated code`
    
    
https://github.com/apache/spark/blob/2be20e54a2222f6cdf64e8486d1910133b43665f/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala#L313-L318
    
    Then run:
    
    ```
    build/sbt "connect-client-jvm/testOnly org.apache.spark.sql.application.ReplE2ESuite" -Phive
    ```
    
    **Before**: we can see the content of all test cases that have been executed in the `ReplE2ESuite`
    
    ```
    Spark session available as 'spark'.
       _____                  __      ______                            __
      / ___/____  ____ ______/ /__   / ____/___  ____  ____  ___  _____/ /_
      \__ \/ __ \/ __ `/ ___/ //_/  / /   / __ \/ __ \/ __ \/ _ \/ ___/ __/
     ___/ / /_/ / /_/ / /  / ,<    / /___/ /_/ / / / / / / /  __/ /__/ /_
    /____/ .___/\__,_/_/  /_/|_|   \____/\____/_/ /_/_/ /_/\___/\___/\__/
        /_/
    
     spark.sql("select 1").collect()
    res0: Array[org.apache.spark.sql.Row] = Array([1])
    
     semaphore.release()
    
     class A(x: Int) { def get = x * 5 + 19 }
    defined class A
    
     def dummyUdf(x: Int): Int = new A(x).get
    defined function dummyUdf
    
     val myUdf = udf(dummyUdf _)
    myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     spark.range(5).select(myUdf(col("id"))).as[Int].collect()
    res5: Array[Int] = Array(19, 24, 29, 34, 39)
    
     semaphore.release()
    
     class A(x: Int) { def get = x * 42 + 5 }
    defined class A
    
     val myUdf = udf((x: Int) => new A(x).get)
    myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     spark.range(5).select(myUdf(col("id"))).as[Int].collect()
    res9: Array[Int] = Array(5, 47, 89, 131, 173)
    
     semaphore.release()
    
     class A(x: Int) { def get = x * 7 }
    defined class A
    
     val myUdf = udf((x: Int) => new A(x).get)
    myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     val modifiedUdf = myUdf.withName("myUdf").asNondeterministic()
    modifiedUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     spark.range(5).select(modifiedUdf(col("id"))).as[Int].collect()
    res14: Array[Int] = Array(0, 7, 14, 21, 28)
    
     semaphore.release()
    
     spark.range(10).filter(n => n % 2 == 0).collect()
    res16: Array[java.lang.Long] = Array(0L, 2L, 4L, 6L, 8L)
    
     semaphore.release()
    
     import java.nio.file.Paths
    import java.nio.file.Paths
    
     def classLoadingTest(x: Int): Int = {
        val classloader =
          Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
        val cls = Class.forName("com.example.Hello$", true, classloader)
        val module = cls.getField("MODULE$").get(null)
        cls.getMethod("test").invoke(module).asInstanceOf[Int]
      }
    defined function classLoadingTest
    
     val classLoaderUdf = udf(classLoadingTest _)
    classLoaderUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     val jarPath = Paths.get("/Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar").toUri
    jarPath: java.net.URI = file:///Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar
    
     spark.addArtifact(jarPath)
    
     spark.range(5).select(classLoaderUdf(col("id"))).as[Int].collect()
    res23: Array[Int] = Array(2, 2, 2, 2, 2)
    
     semaphore.release()
    
     import org.apache.spark.sql.api.java._
    import org.apache.spark.sql.api.java._
    
     import org.apache.spark.sql.types.LongType
    import org.apache.spark.sql.types.LongType
    
     val javaUdf = udf(new UDF1[Long, Long] {
        override def call(num: Long): Long = num * num + 25L
      }, LongType).asNondeterministic()
    javaUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     spark.range(5).select(javaUdf(col("id"))).as[Long].collect()
    res28: Array[Long] = Array(25L, 26L, 29L, 34L, 41L)
    
     semaphore.release()
    
     import org.apache.spark.sql.api.java._
    import org.apache.spark.sql.api.java._
    
     import org.apache.spark.sql.types.LongType
    import org.apache.spark.sql.types.LongType
    
     spark.udf.register("javaUdf", new UDF1[Long, Long] {
        override def call(num: Long): Long = num * num * num + 250L
      }, LongType)
    
     spark.sql("select javaUdf(id) from range(5)").as[Long].collect()
    res33: Array[Long] = Array(250L, 251L, 258L, 277L, 314L)
    
     semaphore.release()
    
     class A(x: Int) { def get = x * 100 }
    defined class A
    
     val myUdf = udf((x: Int) => new A(x).get)
    myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     spark.udf.register("dummyUdf", myUdf)
    res37: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     spark.sql("select dummyUdf(id) from range(5)").as[Long].collect()
    res38: Array[Long] = Array(0L, 100L, 200L, 300L, 400L)
    
     semaphore.release()
    
     class A(x: Int) { def get = x * 15 }
    defined class A
    
     spark.udf.register("directUdf", (x: Int) => new A(x).get)
    res41: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     spark.sql("select directUdf(id) from range(5)").as[Long].collect()
    res42: Array[Long] = Array(0L, 15L, 30L, 45L, 60L)
    
     semaphore.release()
    
     val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
    df: org.apache.spark.sql.package.DataFrame = [id: string, value: int]
    
     spark.udf.register("simpleUDF", (v: Int) => v * v)
    res45: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     df.select($"id", call_udf("simpleUDF", $"value")).collect()
    res46: Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])
    
    semaphore.release()
    
     val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
    df: org.apache.spark.sql.package.DataFrame = [id: string, value: int]
    
     spark.udf.register("simpleUDF", (v: Int) => v * v)
    res49: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     df.select($"id", call_function("simpleUDF", $"value")).collect()
    res50: Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])
    
     semaphore.release()
    
     case class MyTestClass(value: Int)
    defined class MyTestClass
    
     spark.range(4).
        filter($"id" % 2 === 1).
        select($"id".cast("int").as("value")).
        as[MyTestClass].
        collect().
        map(mtc => s"MyTestClass(${mtc.value})").
        mkString("[", ", ", "]")
    res53: String = "[MyTestClass(1), MyTestClass(3)]"
    
     semaphore.release()
    
     case class MyTestClass(value: Int)
    defined class MyTestClass
    
     spark.range(2).map(i => MyTestClass(i.toInt)).collect()
    res56: Array[MyTestClass] = Array(MyTestClass(0), MyTestClass(1))
    
     semaphore.release()
    
     val add1 = udf((i: Long) => i + 1)
    add1: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     val query = {
        spark.readStream
            .format("rate")
            .option("rowsPerSecond", "10")
            .option("numPartitions", "1")
            .load()
            .withColumn("value", add1($"value"))
            .writeStream
            .format("memory")
            .queryName("my_sink")
            .start()
      }
    query: org.apache.spark.sql.streaming.StreamingQuery = org.apache.spark.sql.streaming.RemoteStreamingQuery79cdf37e
    
     var progress = query.lastProgress
    progress: org.apache.spark.sql.streaming.StreamingQueryProgress = null
    
     while (query.isActive && (progress == null || progress.numInputRows == 0)) {
        query.awaitTermination(100)
        progress = query.lastProgress
      }
    
     val noException = query.exception.isEmpty
    noException: Boolean = true
    
     query.stop()
    
     semaphore.release()
    ```
    
    **After**: we can only see the content that is related to the test case `streaming works with REPL generated code`
    
    ```
    
     val add1 = udf((i: Long) => i + 1)
    add1: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
      Array(
    ...
    
     val query = {
        spark.readStream
            .format("rate")
            .option("rowsPerSecond", "10")
            .option("numPartitions", "1")
            .load()
            .withColumn("value", add1($"value"))
            .writeStream
            .format("memory")
            .queryName("my_sink")
            .start()
      }
    query: org.apache.spark.sql.streaming.StreamingQuery = org.apache.spark.sql.streaming.RemoteStreamingQuery5429e19b
    
     var progress = query.lastProgress
    progress: org.apache.spark.sql.streaming.StreamingQueryProgress = null
    
     while (query.isActive && (progress == null || progress.numInputRows == 0)) {
        query.awaitTermination(100)
        progress = query.lastProgress
      }
    
     val noException = query.exception.isEmpty
    noException: Boolean = true
    
     query.stop()
    
     semaphore.release()
    
    ```
    
    Closes #42512 from LuciferYang/SPARK-44824-35.
    
    Authored-by: yangjie01 <yangji...@baidu.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index f467aee73f2..b2971236147 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -83,6 +83,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
 
   override def afterEach(): Unit = {
     semaphore.drainPermits()
+    if (ammoniteOut != null) {
+      ammoniteOut.reset()
+    }
   }
 
   def runCommandsInShell(input: String): String = {

