This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7c5daaaa20b [SPARK-43032][SS][CONNECT] Python SQM bug fix 7c5daaaa20b is described below commit 7c5daaaa20bb012110d5855e5908cc01658355ed Author: Wei Liu <wei....@databricks.com> AuthorDate: Mon May 8 10:08:27 2023 +0900 [SPARK-43032][SS][CONNECT] Python SQM bug fix ### What changes were proposed in this pull request? Some bug fixes for the streaming ***connect*** Python SQM. Note that I also changed the ***non-connect*** StreamingQueryManager `get()` API to return an `Optional[StreamingQuery]`. Before, it looked like this when you got a non-existent query: ``` >>> a = spark.streams.get("00000000-0000-0001-0000-000000000001") >>> a <pyspark.sql.streaming.query.StreamingQuery object at 0x7f86465702b0> >>> a.id Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/wei.liu/oss-spark/python/pyspark/sql/streaming/query.py", line 78, in id return self._jsq.id().toString() AttributeError: 'NoneType' object has no attribute 'id' ``` But now it looks like: ``` >>> a = spark.streams.get("00000000-0000-0001-0000-000000000001") >>> a.id Traceback (most recent call last): File "<stdin>", line 1, in <module> AttributeError: 'NoneType' object has no attribute 'id' ``` The only difference is the return type, which is not typically honored in Python... But I am not very sure whether that counts as a breaking change. ### Why are the changes needed? Bug fix ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manually tested. Also verified that it won't throw even without this fix, so it's not that urgent. Closes #41037 from WweiL/SPARK-43032-python-sqm-fix. 
Authored-by: Wei Liu <wei....@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../src/main/protobuf/spark/connect/commands.proto | 2 +- .../sql/connect/planner/SparkConnectPlanner.scala | 22 +++++++------ python/pyspark/sql/connect/proto/commands_pb2.py | 36 +++++++++++----------- python/pyspark/sql/connect/proto/commands_pb2.pyi | 16 +++++----- python/pyspark/sql/connect/streaming/query.py | 14 ++++++--- python/pyspark/sql/streaming/query.py | 8 +++-- 6 files changed, 54 insertions(+), 44 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto index b929ffa2564..72bc8b5b6ef 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto @@ -330,7 +330,7 @@ message StreamingQueryManagerCommand { // active() API, returns a list of active queries. bool active = 1; // get() API, returns the StreamingQuery identified by id. - string get = 2; + string get_query = 2; // awaitAnyTermination() API, wait until any query terminates or timeout. AwaitAnyTerminationCommand await_any_termination = 3; // resetTerminated() API. 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 8c43f982ec1..01f1e890630 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2466,16 +2466,18 @@ class SparkConnectPlanner(val session: SparkSession) { .toIterable .asJava) - case StreamingQueryManagerCommand.CommandCase.GET => - val query = session.streams.get(command.getGet) - respBuilder.getQueryBuilder - .setId( - StreamingQueryInstanceId - .newBuilder() - .setId(query.id.toString) - .setRunId(query.runId.toString) - .build()) - .setName(SparkConnectService.convertNullString(query.name)) + case StreamingQueryManagerCommand.CommandCase.GET_QUERY => + val query = session.streams.get(command.getGetQuery) + if (query != null) { + respBuilder.getQueryBuilder + .setId( + StreamingQueryInstanceId + .newBuilder() + .setId(query.id.toString) + .setRunId(query.runId.toString) + .build()) + .setName(SparkConnectService.convertNullString(query.name)) + } case StreamingQueryManagerCommand.CommandCase.AWAIT_ANY_TERMINATION => if (command.getAwaitAnyTermination.hasTimeoutMs) { diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py index 9848a40adab..bc764926213 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.py +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - 
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x86\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...] + b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x86\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...] ) @@ -525,21 +525,21 @@ if _descriptor._USE_C_DESCRIPTORS == False: _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 5817 _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 5873 _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 5891 - _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 6230 - _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 6140 - _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 6219 - _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 6233 - _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 6942 - _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 6636 - _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 6763 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 6765 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 6866 - 
_STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 6868 - _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 6927 - _GETRESOURCESCOMMAND._serialized_start = 6944 - _GETRESOURCESCOMMAND._serialized_end = 6965 - _GETRESOURCESCOMMANDRESULT._serialized_start = 6968 - _GETRESOURCESCOMMANDRESULT._serialized_end = 7180 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 7084 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 7180 + _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 6241 + _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 6151 + _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 6230 + _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 6244 + _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 6953 + _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 6647 + _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 6774 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 6776 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 6877 + _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 6879 + _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 6938 + _GETRESOURCESCOMMAND._serialized_start = 6955 + _GETRESOURCESCOMMAND._serialized_end = 6976 + _GETRESOURCESCOMMANDRESULT._serialized_start = 6979 + _GETRESOURCESCOMMANDRESULT._serialized_end = 7191 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 7095 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 7191 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi index 6fec61b02dd..2c80614c3fd 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.pyi +++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi @@ 
-1283,12 +1283,12 @@ class StreamingQueryManagerCommand(google.protobuf.message.Message): ) -> typing_extensions.Literal["timeout_ms"] | None: ... ACTIVE_FIELD_NUMBER: builtins.int - GET_FIELD_NUMBER: builtins.int + GET_QUERY_FIELD_NUMBER: builtins.int AWAIT_ANY_TERMINATION_FIELD_NUMBER: builtins.int RESET_TERMINATED_FIELD_NUMBER: builtins.int active: builtins.bool """active() API, returns a list of active queries.""" - get: builtins.str + get_query: builtins.str """get() API, returns the StreamingQuery identified by id.""" @property def await_any_termination( @@ -1301,7 +1301,7 @@ class StreamingQueryManagerCommand(google.protobuf.message.Message): self, *, active: builtins.bool = ..., - get: builtins.str = ..., + get_query: builtins.str = ..., await_any_termination: global___StreamingQueryManagerCommand.AwaitAnyTerminationCommand | None = ..., reset_terminated: builtins.bool = ..., @@ -1315,8 +1315,8 @@ class StreamingQueryManagerCommand(google.protobuf.message.Message): b"await_any_termination", "command", b"command", - "get", - b"get", + "get_query", + b"get_query", "reset_terminated", b"reset_terminated", ], @@ -1330,8 +1330,8 @@ class StreamingQueryManagerCommand(google.protobuf.message.Message): b"await_any_termination", "command", b"command", - "get", - b"get", + "get_query", + b"get_query", "reset_terminated", b"reset_terminated", ], @@ -1339,7 +1339,7 @@ class StreamingQueryManagerCommand(google.protobuf.message.Message): def WhichOneof( self, oneof_group: typing_extensions.Literal["command", b"command"] ) -> typing_extensions.Literal[ - "active", "get", "await_any_termination", "reset_terminated" + "active", "get_query", "await_any_termination", "reset_terminated" ] | None: ... 
global___StreamingQueryManagerCommand = StreamingQueryManagerCommand diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 606c4d4febc..e5aa881c990 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -187,11 +187,15 @@ class StreamingQueryManager: active.__doc__ = PySparkStreamingQueryManager.active.__doc__ - def get(self, id: str) -> StreamingQuery: + def get(self, id: str) -> Optional[StreamingQuery]: cmd = pb2.StreamingQueryManagerCommand() - cmd.get = id - query = self._execute_streaming_query_manager_cmd(cmd).query - return StreamingQuery(self._session, query.id.id, query.id.run_id, query.name) + cmd.get_query = id + response = self._execute_streaming_query_manager_cmd(cmd) + if response.HasField("query"): + query = response.query + return StreamingQuery(self._session, query.id.id, query.id.run_id, query.name) + else: + return None get.__doc__ = PySparkStreamingQueryManager.get.__doc__ @@ -221,7 +225,7 @@ class StreamingQueryManager: def resetTerminated(self) -> None: cmd = pb2.StreamingQueryManagerCommand() cmd.reset_terminated = True - self._execute_streaming_query_manager_cmd(cmd).active.active_queries + self._execute_streaming_query_manager_cmd(cmd) return None resetTerminated.__doc__ = PySparkStreamingQueryManager.resetTerminated.__doc__ diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index b6268dcdb18..ac7a1acfcaa 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -445,7 +445,7 @@ class StreamingQueryManager: """ return [StreamingQuery(jsq) for jsq in self._jsqm.active()] - def get(self, id: str) -> StreamingQuery: + def get(self, id: str) -> Optional[StreamingQuery]: """ Returns an active query from this :class:`SparkSession`. 
@@ -484,7 +484,11 @@ class StreamingQueryManager: True >>> sq.stop() """ - return StreamingQuery(self._jsqm.get(id)) + query = self._jsqm.get(id) + if query is not None: + return StreamingQuery(query) + else: + return None def awaitAnyTermination(self, timeout: Optional[int] = None) -> Optional[bool]: """ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org