This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b9ca91dde94c [SPARK-47712][CONNECT] Allow connect plugins to create and process Datasets
b9ca91dde94c is described below

commit b9ca91dde94c5ac6eeae9bb5818099adbc93006c
Author: Tom van Bussel <tom.vanbus...@databricks.com>
AuthorDate: Fri Apr 5 10:42:43 2024 +0900

    [SPARK-47712][CONNECT] Allow connect plugins to create and process Datasets
    
    ### What changes were proposed in this pull request?
    
    This PR adds new versions of `SparkSession.newDataFrame` and `SparkSession.newDataset` that take an `Array[Byte]` as input. The older versions that take a `protobuf.Any` are deprecated. The `protobuf.Any` overload of `Column.apply` is likewise deprecated in favor of a new `Column.forExtension(Array[Byte])`. This PR also adds new versions of `SparkConnectPlanner.transformRelation` and `SparkConnectPlanner.transformExpression` that take an `Array[Byte]`.
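
    A minimal client-side sketch under assumptions: `MyPluginRelation` and `MyPluginExpression` stand in for whatever protobuf messages a plugin defines, and `spark` is an existing Connect client `SparkSession`; only `newDataFrame(Array[Byte])`, `newDataset(Array[Byte], ...)` and `Column.forExtension(Array[Byte])` come from this PR.

    ```scala
    import com.google.protobuf.Any
    import org.apache.spark.sql.{Column, DataFrame}

    // Hypothetical plugin-defined protobuf messages (not part of this PR).
    val relation = MyPluginRelation.newBuilder().setCustomField("abc").build()
    val expression = MyPluginExpression.newBuilder().setCustomField("abc").build()

    // Serialize to bytes before handing the message to the client, so the
    // plugin does not need to match Spark's shading of the protobuf classes.
    val df: DataFrame = spark.newDataFrame(Any.pack(relation).toByteArray)
    val col: Column = Column.forExtension(Any.pack(expression).toByteArray)
    ```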
    
    ### Why are the changes needed?
    
    Without these changes it is difficult to create plugins for Spark Connect. The methods above used to take a shaded protobuf class as input, meaning that plugins had to shade these classes in exactly the same way. Now they can simply serialize the protobuf object to bytes and pass those in instead.
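
    For example, a server-side `RelationPlugin` can now hand serialized bytes to the planner. A sketch modeled on the updated `SparkConnectPluginRegistrySuite` (the `transform` signature is assumed from that test code, and `MyRelationPlugin` is a hypothetical name):

    ```scala
    import java.util.Optional

    import com.google.protobuf
    import org.apache.spark.connect.proto
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.connect.planner.SparkConnectPlanner
    import org.apache.spark.sql.connect.plugin.RelationPlugin

    class MyRelationPlugin extends RelationPlugin {
      override def transform(
          rel: protobuf.Any,
          planner: SparkConnectPlanner): Optional[LogicalPlan] = {
        if (!rel.is(classOf[proto.ExamplePluginRelation])) {
          return Optional.empty()
        }
        val plugin = rel.unpack(classOf[proto.ExamplePluginRelation])
        // The planner re-parses the bytes with its own (shaded) protobuf
        // classes, so the plugin's protobuf runtime need not match Spark's.
        Optional.of(planner.transformRelation(plugin.getInput.toByteArray))
      }
    }
    ```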
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Tests were added to `ClientDatasetSuite` and `PlanGenerationTestSuite`, and `SparkConnectPluginRegistrySuite` was updated to exercise the new byte-based planner methods.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45850 from tomvanbussel/SPARK-47712.
    
    Authored-by: Tom van Bussel <tom.vanbus...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/scala/org/apache/spark/sql/Column.scala   |   6 +++++
 .../scala/org/apache/spark/sql/SparkSession.scala  |  14 ++++++++++-
 .../org/apache/spark/sql/ClientDatasetSuite.scala  |  14 ++++++++++-
 .../apache/spark/sql/PlanGenerationTestSuite.scala |  26 +++++++++++++++++++--
 .../expression_extension_deprecated.explain        |   2 ++
 .../relation_extension_deprecated.explain          |   1 +
 .../queries/expression_extension_deprecated.json   |  26 +++++++++++++++++++++
 .../expression_extension_deprecated.proto.bin      | Bin 0 -> 127 bytes
 .../queries/relation_extension_deprecated.json     |  16 +++++++++++++
 .../relation_extension_deprecated.proto.bin        | Bin 0 -> 108 bytes
 .../sql/connect/planner/SparkConnectPlanner.scala  |  11 +++++++++
 .../plugin/SparkConnectPluginRegistrySuite.scala   |   5 ++--
 12 files changed, 114 insertions(+), 7 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
index dec699f4f1a8..c23d49440248 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1351,10 +1351,16 @@ private[sql] object Column {
   }
 
   @DeveloperApi
+  @deprecated("Use forExtension(Array[Byte]) instead", "4.0.0")
   def apply(extension: com.google.protobuf.Any): Column = {
     apply(_.setExtension(extension))
   }
 
+  @DeveloperApi
+  def forExtension(extension: Array[Byte]): Column = {
+    apply(_.setExtension(com.google.protobuf.Any.parseFrom(extension)))
+  }
+
   private[sql] def fn(name: String, inputs: Column*): Column = {
     fn(name, isDistinct = false, inputs: _*)
   }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index adee5b33fb4e..1e467a864442 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -496,17 +496,29 @@ class SparkSession private[sql] (
   }
 
   @DeveloperApi
+  @deprecated("Use newDataFrame(Array[Byte]) instead", "4.0.0")
   def newDataFrame(extension: com.google.protobuf.Any): DataFrame = {
-    newDataset(extension, UnboundRowEncoder)
+    newDataFrame(_.setExtension(extension))
   }
 
   @DeveloperApi
+  @deprecated("Use newDataFrame(Array[Byte], AgnosticEncoder[T]) instead", 
"4.0.0")
   def newDataset[T](
       extension: com.google.protobuf.Any,
       encoder: AgnosticEncoder[T]): Dataset[T] = {
     newDataset(encoder)(_.setExtension(extension))
   }
 
+  @DeveloperApi
+  def newDataFrame(extension: Array[Byte]): DataFrame = {
+    newDataFrame(_.setExtension(com.google.protobuf.Any.parseFrom(extension)))
+  }
+
+  @DeveloperApi
+  def newDataset[T](extension: Array[Byte], encoder: AgnosticEncoder[T]): Dataset[T] = {
+    newDataset(encoder)(_.setExtension(com.google.protobuf.Any.parseFrom(extension)))
+  }
+
   private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = {
     val builder = proto.Command.newBuilder()
     f(builder)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala
index 041b09283658..4a32b8460bce 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala
@@ -162,7 +162,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
     }
   }
 
-  test("command extension") {
+  test("command extension deprecated") {
     val extension = proto.ExamplePluginCommand.newBuilder().setCustomField("abc").build()
     val command = proto.Command
       .newBuilder()
@@ -174,6 +174,18 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
     assert(actualPlan.equals(expectedPlan))
   }
 
+  test("command extension") {
+    val extension = proto.ExamplePluginCommand.newBuilder().setCustomField("abc").build()
+    val command = proto.Command
+      .newBuilder()
+      .setExtension(com.google.protobuf.Any.pack(extension))
+      .build()
+    val expectedPlan = proto.Plan.newBuilder().setCommand(command).build()
+    ss.execute(com.google.protobuf.Any.pack(extension).toByteArray)
+    val actualPlan = service.getAndClearLatestInputPlan()
+    assert(actualPlan.equals(expectedPlan))
+  }
+
   test("serialize as null") {
     val session = newSparkSession()
     val ds = session.range(10)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 5fde8b04735b..5844df8a4889 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -3191,7 +3191,7 @@ class PlanGenerationTestSuite
   }
 
   /* Extensions */
-  test("relation extension") {
+  test("relation extension deprecated") {
     val input = proto.ExamplePluginRelation
       .newBuilder()
       .setInput(simple.plan.getRoot)
@@ -3199,7 +3199,7 @@ class PlanGenerationTestSuite
     session.newDataFrame(com.google.protobuf.Any.pack(input))
   }
 
-  test("expression extension") {
+  test("expression extension deprecated") {
     val extension = proto.ExamplePluginExpression
       .newBuilder()
       .setChild(
@@ -3213,6 +3213,28 @@ class PlanGenerationTestSuite
     simple.select(Column(com.google.protobuf.Any.pack(extension)))
   }
 
+  test("relation extension") {
+    val input = proto.ExamplePluginRelation
+      .newBuilder()
+      .setInput(simple.plan.getRoot)
+      .build()
+    session.newDataFrame(com.google.protobuf.Any.pack(input).toByteArray)
+  }
+
+  test("expression extension") {
+    val extension = proto.ExamplePluginExpression
+      .newBuilder()
+      .setChild(
+        proto.Expression
+          .newBuilder()
+          .setUnresolvedAttribute(proto.Expression.UnresolvedAttribute
+            .newBuilder()
+            .setUnparsedIdentifier("id")))
+      .setCustomField("abc")
+      .build()
+    simple.select(Column.forExtension(com.google.protobuf.Any.pack(extension).toByteArray))
+  }
+
   test("crosstab") {
     simple.stat.crosstab("a", "b")
   }
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension_deprecated.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension_deprecated.explain
new file mode 100644
index 000000000000..7426332004a8
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/expression_extension_deprecated.explain
@@ -0,0 +1,2 @@
+Project [id#0L AS abc#0L]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension_deprecated.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension_deprecated.explain
new file mode 100644
index 000000000000..df724a7dd185
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/relation_extension_deprecated.explain
@@ -0,0 +1 @@
+LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.json b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.json
new file mode 100644
index 000000000000..acfb3cc2333d
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.json
@@ -0,0 +1,26 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "expressions": [{
+      "extension": {
+        "@type": "type.googleapis.com/spark.connect.ExamplePluginExpression",
+        "child": {
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "id"
+          }
+        },
+        "customField": "abc"
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.proto.bin
new file mode 100644
index 000000000000..24669eba6423
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/expression_extension_deprecated.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.json b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.json
new file mode 100644
index 000000000000..47ceba13ca7e
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.json
@@ -0,0 +1,16 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "extension": {
+    "@type": "type.googleapis.com/spark.connect.ExamplePluginRelation",
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.proto.bin
new file mode 100644
index 000000000000..680bb550eca5
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/relation_extension_deprecated.proto.bin differ
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 1894ab984490..40dc7f88255e 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -30,6 +30,7 @@ import io.grpc.stub.StreamObserver
 import org.apache.commons.lang3.exception.ExceptionUtils
 
 import org.apache.spark.{Partition, SparkEnv, TaskContext}
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
@@ -202,6 +203,11 @@ class SparkConnectPlanner(
     plan
   }
 
+  @DeveloperApi
+  def transformRelation(bytes: Array[Byte]): LogicalPlan = {
+    transformRelation(proto.Relation.parseFrom(bytes))
+  }
+
   private def transformRelationPlugin(extension: ProtoAny): LogicalPlan = {
     SparkConnectPluginRegistry.relationRegistry
       // Lazily traverse the collection.
@@ -1470,6 +1476,11 @@ class SparkConnectPlanner(
     }
   }
 
+  @DeveloperApi
+  def transformExpression(bytes: Array[Byte]): Expression = {
+    transformExpression(proto.Expression.parseFrom(bytes))
+  }
+
   private def toNamedExpression(expr: Expression): NamedExpression = expr match {
     case named: NamedExpression => named
     case expr => UnresolvedAlias(expr)
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
index ff8cac7a35d6..a213a36168e8 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
@@ -68,7 +68,7 @@ class ExampleRelationPlugin extends RelationPlugin {
       return Optional.empty()
     }
     val plugin = rel.unpack(classOf[proto.ExamplePluginRelation])
-    Optional.of(planner.transformRelation(plugin.getInput))
+    Optional.of(planner.transformRelation(plugin.getInput.toByteArray))
   }
 }
 
@@ -82,8 +82,7 @@ class ExampleExpressionPlugin extends ExpressionPlugin {
     }
     val exp = rel.unpack(classOf[proto.ExamplePluginExpression])
     Optional.of(
-      Alias(planner.transformExpression(exp.getChild), exp.getCustomField)(explicitMetadata =
-        None))
+      Alias(planner.transformExpression(exp.getChild.toByteArray), exp.getCustomField)())
   }
 }
 

