This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e64afb6f9d4 [SPARK-40899][CONNECT] Make UserContext extensible
e64afb6f9d4 is described below

commit e64afb6f9d47ded64d85091a4fefac2d8657b1c0
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Tue Oct 25 11:18:00 2022 +0800

    [SPARK-40899][CONNECT] Make UserContext extensible
    
    ### What changes were proposed in this pull request?
    
    Different systems will need different metadata passed as the user context
    during a request. To handle these different systems seamlessly, make the
    `UserContext` message extensible with `google.protobuf.Any`.
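    
    For illustration, a client could attach a deployment-specific message to the
    user context and a server that recognizes the type could unpack it again. A
    minimal sketch mirroring the new unit test (`proto.Expression` merely stands
    in for an arbitrary extension message; the server-side handling is
    hypothetical):
    
    ```scala
    import com.google.protobuf.Any
    
    // Client side: pack an arbitrary message into the repeated extensions field.
    val context = proto.Request.UserContext
      .newBuilder()
      .setUserId("1")
      .addExtensions(
        Any.pack(
          proto.Expression
            .newBuilder()
            .setLiteral(proto.Expression.Literal.newBuilder().setI32(32))
            .build()))
      .build()
    
    // Server side: unpack only the extension types this deployment understands.
    context.getExtensionsList.forEach { ext =>
      if (ext.is(classOf[proto.Expression])) {
        val payload = ext.unpack(classOf[proto.Expression])
        // handle the deployment-specific payload here
      }
    }
    ```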
    
    ### Why are the changes needed?
    Extensibility.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UT
    
    Closes #38374 from grundprinzip/SPARK-40899.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../src/main/protobuf/spark/connect/base.proto     |  7 +++
 .../messages/ConnectProtoMessagesSuite.scala       | 51 ++++++++++++++++++++++
 python/pyspark/sql/connect/proto/base_pb2.py       | 51 +++++++++++-----------
 python/pyspark/sql/connect/proto/base_pb2.pyi      | 18 +++++++-
 4 files changed, 101 insertions(+), 26 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index 390a8b156dc..dff1734335e 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -19,6 +19,7 @@ syntax = 'proto3';
 
 package spark.connect;
 
+import "google/protobuf/any.proto";
 import "spark/connect/commands.proto";
 import "spark/connect/relations.proto";
 
@@ -51,6 +52,12 @@ message Request {
   message UserContext {
     string user_id = 1;
     string user_name = 2;
+
+    // To extend the existing user context message that is used to identify incoming requests,
+    // Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other
+    // messages into this message. Extensions are stored as a `repeated` type to be able to
+    // handle multiple active extensions.
+    repeated google.protobuf.Any extensions = 999;
   }
 }
 
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
new file mode 100644
index 00000000000..4132cca9108
--- /dev/null
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.messages
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.connect.proto
+
+class ConnectProtoMessagesSuite extends SparkFunSuite {
+  test("UserContext can deal with extensions") {
+    // Create the builder.
+    val builder = proto.Request.UserContext.newBuilder().setUserId("1").setUserName("Martin")
+
+    // Create the extension value.
+    val lit = proto.Expression
+      .newBuilder()
+      .setLiteral(proto.Expression.Literal.newBuilder().setI32(32).build())
+    // Pack the extension into Any.
+    val aval = com.google.protobuf.Any.pack(lit.build())
+    // Add Any to the repeated field list.
+    builder.addExtensions(aval)
+    // Create serialized value.
+    val serialized = builder.build().toByteArray
+
+    // Now, read the serialized value.
+    val result = proto.Request.UserContext.parseFrom(serialized)
+    assert(result.getUserId.equals("1"))
+    assert(result.getUserName.equals("Martin"))
+    assert(result.getExtensionsCount == 1)
+
+    val ext = result.getExtensions(0)
+    assert(ext.is(classOf[proto.Expression]))
+    val extLit = ext.unpack(classOf[proto.Expression])
+    assert(extLit.hasLiteral)
+    assert(extLit.getLiteral.hasI32)
+    assert(extLit.getLiteral.getI32 == 32)
+  }
+}
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 8de6565bae8..408872dbb66 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -28,12 +28,13 @@ from google.protobuf import symbol_database as _symbol_database
 _sym_db = _symbol_database.Default()
 
 
+from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
 from pyspark.sql.connect.proto import commands_pb2 as spark_dot_connect_dot_commands__pb2
 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xdb\x01\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buse [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.co [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,28 +45,28 @@ if _descriptor._USE_C_DESCRIPTORS == False:
    DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None
    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001"
-    _PLAN._serialized_start = 104
-    _PLAN._serialized_end = 220
-    _REQUEST._serialized_start = 223
-    _REQUEST._serialized_end = 442
-    _REQUEST_USERCONTEXT._serialized_start = 375
-    _REQUEST_USERCONTEXT._serialized_end = 442
-    _RESPONSE._serialized_start = 445
-    _RESPONSE._serialized_end = 1413
-    _RESPONSE_ARROWBATCH._serialized_start = 674
-    _RESPONSE_ARROWBATCH._serialized_end = 849
-    _RESPONSE_JSONBATCH._serialized_start = 851
-    _RESPONSE_JSONBATCH._serialized_end = 911
-    _RESPONSE_METRICS._serialized_start = 914
-    _RESPONSE_METRICS._serialized_end = 1398
-    _RESPONSE_METRICS_METRICOBJECT._serialized_start = 998
-    _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1308
-    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1196
-    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1308
-    _RESPONSE_METRICS_METRICVALUE._serialized_start = 1310
-    _RESPONSE_METRICS_METRICVALUE._serialized_end = 1398
-    _ANALYZERESPONSE._serialized_start = 1416
-    _ANALYZERESPONSE._serialized_end = 1571
-    _SPARKCONNECTSERVICE._serialized_start = 1574
-    _SPARKCONNECTSERVICE._serialized_end = 1736
+    _PLAN._serialized_start = 131
+    _PLAN._serialized_end = 247
+    _REQUEST._serialized_start = 250
+    _REQUEST._serialized_end = 524
+    _REQUEST_USERCONTEXT._serialized_start = 402
+    _REQUEST_USERCONTEXT._serialized_end = 524
+    _RESPONSE._serialized_start = 527
+    _RESPONSE._serialized_end = 1495
+    _RESPONSE_ARROWBATCH._serialized_start = 756
+    _RESPONSE_ARROWBATCH._serialized_end = 931
+    _RESPONSE_JSONBATCH._serialized_start = 933
+    _RESPONSE_JSONBATCH._serialized_end = 993
+    _RESPONSE_METRICS._serialized_start = 996
+    _RESPONSE_METRICS._serialized_end = 1480
+    _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1080
+    _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1390
+    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1278
+    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1390
+    _RESPONSE_METRICS_METRICVALUE._serialized_start = 1392
+    _RESPONSE_METRICS_METRICVALUE._serialized_end = 1480
+    _ANALYZERESPONSE._serialized_start = 1498
+    _ANALYZERESPONSE._serialized_end = 1653
+    _SPARKCONNECTSERVICE._serialized_start = 1656
+    _SPARKCONNECTSERVICE._serialized_end = 1818
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index bcac7d11a80..bb3a6578cf7 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -35,6 +35,7 @@ limitations under the License.
 """
 import builtins
 import collections.abc
+import google.protobuf.any_pb2
 import google.protobuf.descriptor
 import google.protobuf.internal.containers
 import google.protobuf.message
@@ -102,17 +103,32 @@ class Request(google.protobuf.message.Message):
 
         USER_ID_FIELD_NUMBER: builtins.int
         USER_NAME_FIELD_NUMBER: builtins.int
+        EXTENSIONS_FIELD_NUMBER: builtins.int
         user_id: builtins.str
         user_name: builtins.str
+        @property
+        def extensions(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+            google.protobuf.any_pb2.Any
+        ]:
+            """To extend the existing user context message that is used to 
identify incoming requests,
+            Spark Connect leverages the Any protobuf type that can be used to 
inject arbitrary other
+            messages into this message. Extensions are stored as a `repeated` 
type to be able to
+            handle multiple active extensions.
+            """
         def __init__(
             self,
             *,
             user_id: builtins.str = ...,
             user_name: builtins.str = ...,
+            extensions: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ...,
         ) -> None: ...
         def ClearField(
             self,
-            field_name: typing_extensions.Literal["user_id", b"user_id", "user_name", b"user_name"],
+            field_name: typing_extensions.Literal[
+                "extensions", b"extensions", "user_id", b"user_id", 
"user_name", b"user_name"
+            ],
         ) -> None: ...
 
     CLIENT_ID_FIELD_NUMBER: builtins.int

