This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new c230a5011a6 [SPARK-44974][CONNECT] Null out SparkSession/Dataset/KeyValueGroupedDataset on serialization
c230a5011a6 is described below

commit c230a5011a6d45c0f393833995b052930f11c324
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Mon Aug 28 15:05:18 2023 +0200

    [SPARK-44974][CONNECT] Null out SparkSession/Dataset/KeyValueGroupedDataset on serialization
    
    ### What changes were proposed in this pull request?
    This PR changes the serialization of the connect `SparkSession`, `Dataset`,
    and `KeyValueGroupedDataset` classes. While these were marked as serializable,
    they effectively were not, because they reference components that are not
    serializable. Even if we were to fix this, we would still have a class clash
    problem with server-side classes that have the same name but a different
    structure. The latter could be fixed with serialization proxies, but I am
    going to hold off on that until someone actually needs/wants it.
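    
    For reference, a serialization proxy here would look roughly like the sketch
    below. It is a hypothetical example (the `RemoteDataset`, `RemoteDatasetProxy`,
    and `planId` names are made up), not part of this PR:
    
        // Sketch of the serialization proxy pattern (hypothetical names).
        // writeReplace swaps the instance for a small serializable proxy;
        // readResolve rebuilds an equivalent instance on deserialization.
        class RemoteDataset(val planId: Long) extends Serializable {
          private def writeReplace(): Any = new RemoteDatasetProxy(planId)
        }
    
        private class RemoteDatasetProxy(val planId: Long) extends Serializable {
          private def readResolve(): Any = new RemoteDataset(planId)
        }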
    
    After this PR these classes are serialized as null. This is somewhat
    suboptimal compared to throwing an exception on serialization; however, it
    is more compatible with the old behavior and makes accidental capture of
    these classes less of an issue for UDFs.
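    
    A minimal, self-contained illustration of the `writeReplace` hook this PR
    relies on (plain JDK serialization, not Spark code):
    
        import java.io._
    
        // During Java serialization, writeReplace substitutes the object being
        // written; returning null makes the instance deserialize as null.
        class Handle extends Serializable {
          private def writeReplace(): Any = null
        }
    
        object WriteReplaceDemo extends App {
          val buffer = new ByteArrayOutputStream()
          val out = new ObjectOutputStream(buffer)
          out.writeObject(new Handle)
          out.close()
          val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
          assert(in.readObject() == null) // the Handle came back as null
        }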
    
    ### Why are the changes needed?
    The change keeps behavior compatible with the old situation and improves
    the UX when working with UDFs.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added tests to `ClientDatasetSuite`, `KeyValueGroupedDatasetE2ETestSuite`,
    `SparkSessionSuite`, and `UserDefinedFunctionE2ETestSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #42688 from hvanhovell/SPARK-44974.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit f0b04286022e0774d78b9adcf4aeabc181a3ec89)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../jvm/src/main/scala/org/apache/spark/sql/Dataset.scala |  6 ++++++
 .../org/apache/spark/sql/KeyValueGroupedDataset.scala     |  6 ++++++
 .../main/scala/org/apache/spark/sql/SparkSession.scala    |  6 ++++++
 .../scala/org/apache/spark/sql/ClientDatasetSuite.scala   |  8 ++++++++
 .../spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala    |  7 +++++++
 .../scala/org/apache/spark/sql/SparkSessionSuite.scala    |  7 +++++++
 .../spark/sql/UserDefinedFunctionE2ETestSuite.scala       | 15 +++++++++++++++
 7 files changed, 55 insertions(+)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index cb7d2c84df5..bdaa4e28ba8 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3336,4 +3336,10 @@ class Dataset[T] private[sql] (
       result.close()
     }
   }
+
+  /**
+   * We cannot deserialize a connect [[Dataset]] because of a class clash on the server side. We
+   * null out the instance for now.
+   */
+  private def writeReplace(): Any = null
 }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 202891c66d7..88c8b6a4f8b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -979,6 +979,12 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
       outputEncoder = outputEncoder)
     udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction
   }
+
+  /**
+   * We cannot deserialize a connect [[KeyValueGroupedDataset]] because of a class clash on the
+   * server side. We null out the instance for now.
+   */
+  private def writeReplace(): Any = null
 }
 
 private object KeyValueGroupedDatasetImpl {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index e902e04e246..7882ea64013 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -714,6 +714,12 @@ class SparkSession private[sql] (
   def clearTags(): Unit = {
     client.clearTags()
   }
+
+  /**
+   * We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side.
+   * We null out the instance for now.
+   */
+  private def writeReplace(): Any = null
 }
 
 // The minimal builder needed to create a spark session.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala
index a521c6745a9..aab31d97e8c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.ConnectFunSuite
+import org.apache.spark.util.SparkSerDeUtils
 
 // Add sample tests.
 // - sample fraction: simple.sample(0.1)
@@ -172,4 +173,11 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
     val actualPlan = service.getAndClearLatestInputPlan()
     assert(actualPlan.equals(expectedPlan))
   }
+
+  test("serialize as null") {
+    val session = newSparkSession()
+    val ds = session.range(10)
+    val bytes = SparkSerDeUtils.serialize(ds)
+    assert(SparkSerDeUtils.deserialize[Dataset[Long]](bytes) == null)
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index 3e979be73a7..98a947826e3 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
 import org.apache.spark.sql.test.{QueryTest, SQLHelper}
 import org.apache.spark.sql.types._
+import org.apache.spark.util.SparkSerDeUtils
 
 case class ClickEvent(id: String, timestamp: Timestamp)
 
@@ -630,6 +631,12 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
       30,
       3)
   }
+
+  test("serialize as null") {
+    val kvgds = session.range(10).groupByKey(_ % 2)
+    val bytes = SparkSerDeUtils.serialize(kvgds)
+    assert(SparkSerDeUtils.deserialize[KeyValueGroupedDataset[Long, Long]](bytes) == null)
+  }
 }
 
 case class K1(a: Long)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
index 90fe8f57d07..4c858262c6e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
@@ -23,6 +23,7 @@ import scala.util.control.NonFatal
 import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor}
 
 import org.apache.spark.sql.test.ConnectFunSuite
+import org.apache.spark.util.SparkSerDeUtils
 
 /**
  * Tests for non-dataframe related SparkSession operations.
@@ -261,4 +262,10 @@ class SparkSessionSuite extends ConnectFunSuite {
       .create()
       .close()
   }
+
+  test("serialize as null") {
+    val session = SparkSession.builder().create()
+    val bytes = SparkSerDeUtils.serialize(session)
+    assert(SparkSerDeUtils.deserialize[SparkSession](bytes) == null)
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index 0af8c78a1da..fbc2c1c2662 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -328,4 +328,19 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest {
       IntegerType)
     checkDataset(session.range(2).select(fn($"id", $"id" + 2)).as[Int], 3, 5)
   }
+
+  test("nullified SparkSession/Dataset/KeyValueGroupedDataset in UDF") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val df = session.range(0, 10, 1, 1)
+    val kvgds = df.groupByKey(_ / 2)
+    val f = udf { (i: Long) =>
+      assert(session == null)
+      assert(df == null)
+      assert(kvgds == null)
+      i + 1
+    }
+    val result = df.select(f($"id")).as[Long].head
+    assert(result == 1L)
+  }
 }

