This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 9389a2ccacce [SPARK-45072][CONNECT] Fix outer scopes for Ammonite classes
9389a2ccacce is described below

commit 9389a2ccacce61cbbbc9bbb1b19b2825d932ba11
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Sep 5 15:35:12 2023 +0200

    [SPARK-45072][CONNECT] Fix outer scopes for Ammonite classes
    
    ### What changes were proposed in this pull request?
    Ammonite places all user code inside a Helper class, which is nested
    inside the class it generates for each command. This PR adds a custom
    code class wrapper for the Ammonite REPL that immediately registers
    every generated Helper class as an outer scope. This way we can
    instantiate classes defined inside the Helper class, even when we
    execute Spark code as part of the Helper's constructor.
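    
    Conceptually, the new wrapper delegates to Ammonite's `CodeClassWrapper`
    and appends the outer-scope registration to the generated preamble. A
    minimal sketch of the resulting generated code follows; the `cmd1` and
    `Helper` names are illustrative placeholders, not Ammonite's actual
    generated names:
    ```scala
    // Rough shape of an Ammonite command wrapper (names are placeholders).
    object cmd1 {
      class Helper {
        // Injected by the new wrapper, before any user code runs:
        org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this)
        // ... user code from the cell/command goes here ...
      }
    }
    ```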
    
    ### Why are the changes needed?
    Currently, defining a class and executing a Spark command that uses it
    in the same cell/line fails with a NullPointerException, because we
    cannot resolve the outer scope needed to instantiate the class. This PR
    fixes that issue. The following code now executes successfully (the
    curly braces make Ammonite compile the block as a single cell):
    ```scala
    {
      case class Thing(value: String)
      val r = (0 to 10).map(value => Thing(value.toString))
      spark.createDataFrame(r)
    }
    ```
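    
    A similarly shaped single-cell snippet that routes the class through an
    encoder also works now. This is a sketch mirroring one of the new
    `ReplE2ESuite` tests below, not additional patch code:
    ```scala
    {
      case class MyTestClass(value: Int)
      val data = (0 to 10).map(MyTestClass)
      spark.createDataset(data).map(mtc => mtc.value).select(sum($"value")).as[Long].head
    }
    ```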
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    I added more tests to the `ReplE2ESuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #42807 from hvanhovell/SPARK-45072.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 40943c2748fdd28d970d017cb8ee86c294ee62df)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../apache/spark/sql/application/ConnectRepl.scala | 29 +++++++++++--
 .../spark/sql/application/ReplE2ESuite.scala       | 48 ++++++++++++++++++----
 .../CheckConnectJvmClientCompatibility.scala       |  6 +++
 3 files changed, 71 insertions(+), 12 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
index e6ada566398c..0360a4057886 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
@@ -22,7 +22,8 @@ import java.util.concurrent.Semaphore
 import scala.util.control.NonFatal
 
 import ammonite.compiler.CodeClassWrapper
-import ammonite.util.Bind
+import ammonite.compiler.iface.CodeWrapper
+import ammonite.util.{Bind, Imports, Name, Util}
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SparkSession
@@ -94,8 +95,8 @@ object ConnectRepl {
     val main = ammonite.Main(
       welcomeBanner = Option(splash),
       predefCode = predefCode,
-      replCodeWrapper = CodeClassWrapper,
-      scriptCodeWrapper = CodeClassWrapper,
+      replCodeWrapper = ExtendedCodeClassWrapper,
+      scriptCodeWrapper = ExtendedCodeClassWrapper,
       inputStream = inputStream,
       outputStream = outputStream,
       errorStream = errorStream)
@@ -107,3 +108,25 @@ object ConnectRepl {
     }
   }
 }
+
+/**
+ * [[CodeWrapper]] that makes sure new Helper classes are always registered as an outer scope.
+ */
+@DeveloperApi
+object ExtendedCodeClassWrapper extends CodeWrapper {
+  override def wrapperPath: Seq[Name] = CodeClassWrapper.wrapperPath
+  override def apply(
+      code: String,
+      source: Util.CodeSource,
+      imports: Imports,
+      printCode: String,
+      indexedWrapper: Name,
+      extraCode: String): (String, String, Int) = {
+    val (top, bottom, level) =
+      CodeClassWrapper(code, source, imports, printCode, indexedWrapper, extraCode)
+    // Make sure we register the Helper before anything else, so outer scopes work as expected.
+    val augmentedTop = top +
+      "\norg.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this)\n"
+    (augmentedTop, bottom, level)
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 4106d298dbe2..5bb8cbf3543b 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -79,12 +79,10 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
 
   override def afterEach(): Unit = {
     semaphore.drainPermits()
-    if (ammoniteOut != null) {
-      ammoniteOut.reset()
-    }
   }
 
   def runCommandsInShell(input: String): String = {
+    ammoniteOut.reset()
     require(input.nonEmpty)
     // Pad the input with a semaphore release so that we know when the execution of the provided
     // input is complete.
@@ -105,6 +103,10 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     getCleanString(ammoniteOut)
   }
 
+  def runCommandsUsingSingleCellInShell(input: String): String = {
+    runCommandsInShell("{\n" + input + "\n}")
+  }
+
   def assertContains(message: String, output: String): Unit = {
     val isContain = output.contains(message)
     assert(
@@ -263,6 +265,31 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
   }
 
+  test("Single Cell Compilation") {
+    val input =
+      """
+        |case class C1(value: Int)
+        |case class C2(value: Int)
+        |val h1 = classOf[C1].getDeclaringClass
+        |val h2 = classOf[C2].getDeclaringClass
+        |val same = h1 == h2
+        |""".stripMargin
+    assertContains("same: Boolean = false", runCommandsInShell(input))
+    assertContains("same: Boolean = true", 
runCommandsUsingSingleCellInShell(input))
+  }
+
+  test("Local relation containing REPL generated class") {
+    val input =
+      """
+        |case class MyTestClass(value: Int)
+        |val data = (0 to 10).map(MyTestClass)
+        |spark.createDataset(data).map(mtc => mtc.value).select(sum($"value")).as[Long].head
+        |""".stripMargin
+    val expected = "Long = 55L"
+    assertContains(expected, runCommandsInShell(input))
+    assertContains(expected, runCommandsUsingSingleCellInShell(input))
+  }
+
   test("Collect REPL generated class") {
     val input =
       """
@@ -275,8 +302,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
         |  map(mtc => s"MyTestClass(${mtc.value})").
         |  mkString("[", ", ", "]")
           """.stripMargin
-    val output = runCommandsInShell(input)
-    assertContains("""String = "[MyTestClass(1), MyTestClass(3)]"""", output)
+    val expected = """String = "[MyTestClass(1), MyTestClass(3)]""""
+    assertContains(expected, runCommandsInShell(input))
+    assertContains(expected, runCommandsUsingSingleCellInShell(input))
   }
 
   test("REPL class in encoder") {
@@ -288,8 +316,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
         |  map(mtc => mtc.value).
         |  collect()
       """.stripMargin
-    val output = runCommandsInShell(input)
-    assertContains("Array[Int] = Array(0, 1, 2)", output)
+    val expected = "Array[Int] = Array(0, 1, 2)"
+    assertContains(expected, runCommandsInShell(input))
+    assertContains(expected, runCommandsUsingSingleCellInShell(input))
   }
 
   test("REPL class in UDF") {
@@ -301,8 +330,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
         |  map(mtc => s"MyTestClass(${mtc.value})").
         |  mkString("[", ", ", "]")
       """.stripMargin
-    val output = runCommandsInShell(input)
-    assertContains("""String = "[MyTestClass(0), MyTestClass(1)]"""", output)
+    val expected = """String = "[MyTestClass(0), MyTestClass(1)]""""
+    assertContains(expected, runCommandsInShell(input))
+    assertContains(expected, runCommandsUsingSingleCellInShell(input))
   }
 
   test("streaming works with REPL generated code") {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 72b0f02f378d..0cc1a44b2732 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -358,6 +358,12 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem](
         "org.apache.spark.sql.application.ConnectRepl$" // developer API
       ),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.application.ExtendedCodeClassWrapper" // 
developer API
+      ),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.application.ExtendedCodeClassWrapper$" // 
developer API
+      ),
 
       // SparkSession
       // developer API


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
