This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e8f58a9c4a64 [SPARK-48370][SPARK-48258][CONNECT][PYTHON][FOLLOW-UP] Refactor local and eager required fields in CheckpointCommand
e8f58a9c4a64 is described below

commit e8f58a9c4a641b830c5304b34b876e0cd5d3ed8e
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu May 23 13:50:34 2024 +0900

    [SPARK-48370][SPARK-48258][CONNECT][PYTHON][FOLLOW-UP] Refactor local and eager required fields in CheckpointCommand

    ### What changes were proposed in this pull request?

    This PR is a followup of https://github.com/apache/spark/pull/46683 and
    https://github.com/apache/spark/pull/46570 that refactors the `local` and
    `eager` required fields in `CheckpointCommand`.

    ### Why are the changes needed?

    To make the code easier to maintain.

    ### Does this PR introduce _any_ user-facing change?

    No, the main change has not been released yet.

    ### How was this patch tested?

    Manually tested.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #46712 from HyukjinKwon/SPARK-48370-SPARK-48258-followup.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  2 +-
 .../src/main/protobuf/spark/connect/commands.proto |  8 ++---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 12 ++-----
 python/pyspark/sql/connect/dataframe.py            |  2 +-
 python/pyspark/sql/connect/proto/commands_pb2.py   | 10 +++---
 python/pyspark/sql/connect/proto/commands_pb2.pyi  | 41 ++++------------------
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  2 +-
 7 files changed, 21 insertions(+), 56 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index fc9766357cb2..5ac07270b22b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3481,7 +3481,7 @@ class Dataset[T] private[sql] (
     sparkSession.newDataset(agnosticEncoder) { builder =>
       val command = sparkSession.newCommand { builder =>
         builder.getCheckpointCommandBuilder
-          .setLocal(reliableCheckpoint)
+          .setLocal(!reliableCheckpoint)
           .setEager(eager)
           .setRelation(this.plan.getRoot)
       }

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index c526f8d3f65d..0e0c55fa34f0 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -497,10 +497,10 @@ message CheckpointCommand {
   // (Required) The logical plan to checkpoint.
   Relation relation = 1;
 
-  // (Optional) Locally checkpoint using a local temporary
+  // (Required) Locally checkpoint using a local temporary
   // directory in Spark Connect server (Spark Driver)
-  optional bool local = 2;
+  bool local = 2;
 
-  // (Optional) Whether to checkpoint this dataframe immediately.
-  optional bool eager = 3;
+  // (Required) Whether to checkpoint this dataframe immediately.
+  bool eager = 3;
 }
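In proto3 terms, dropping the `optional` keyword means `local` and `eager` no longer track
field presence: they always carry a value, defaulting to false when unset, which is why the
planner change below no longer needs `hasLocal`/`hasEager` branches. A minimal sketch of that
behavior against the regenerated module (assuming a PySpark build that already includes this
patch):

    from pyspark.sql.connect.proto import commands_pb2

    cmd = commands_pb2.CheckpointCommand()
    # Plain proto3 scalars: always readable, defaulting to False when unset.
    print(cmd.local, cmd.eager)   # False False

    # Setting local=True asks the server for a local checkpoint; leaving it
    # False maps to a reliable checkpoint in the server's checkpoint directory.
    cmd.local = False
    cmd.eager = True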
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cbc60d2873f9..a339469e61cd 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -3523,15 +3523,9 @@ class SparkConnectPlanner(
       responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
     val target = Dataset
       .ofRows(session, transformRelation(checkpointCommand.getRelation))
-    val checkpointed = if (checkpointCommand.hasLocal && checkpointCommand.hasEager) {
-      target.localCheckpoint(eager = checkpointCommand.getEager)
-    } else if (checkpointCommand.hasLocal) {
-      target.localCheckpoint()
-    } else if (checkpointCommand.hasEager) {
-      target.checkpoint(eager = checkpointCommand.getEager)
-    } else {
-      target.checkpoint()
-    }
+    val checkpointed = target.checkpoint(
+      eager = checkpointCommand.getEager,
+      reliableCheckpoint = !checkpointCommand.getLocal)
     val dfId = UUID.randomUUID().toString
     logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 510776bb752d..62c73da374bc 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -2096,7 +2096,7 @@ class DataFrame(ParentDataFrame):
         return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)
 
     def checkpoint(self, eager: bool = True) -> "DataFrame":
-        cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
+        cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager)
         _, properties = self._session.client.execute_command(cmd.command(self._session.client))
         assert "checkpoint_command_result" in properties
         checkpointed = properties["checkpoint_command_result"]
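Taken together, the Python client now always sends an explicit `local` flag and the server
derives `reliableCheckpoint = !local`. A minimal end-to-end usage sketch against a Spark
Connect server (the `sc://localhost:15002` address and the server-side checkpoint directory
are illustrative assumptions, not part of this patch):

    from pyspark.sql import SparkSession

    # Assumes a running Spark Connect server at this address.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.range(10)

    # Reliable checkpoint: CheckpointCommand.local = False; requires a
    # checkpoint directory to be configured on the server side.
    reliable = df.checkpoint(eager=True)

    # Local checkpoint: CheckpointCommand.local = True, kept in the server's
    # caching subsystem instead of the checkpoint directory.
    local = df.localCheckpoint(eager=True)

    print(reliable.count(), local.count())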
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index 43673d9707a9..8f67f817c3f0 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xaf\x0c\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...]
+    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xaf\x0c\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -71,8 +71,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_options = b"8\001"
     _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._options = None
     _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_options = b"8\001"
-    _STREAMINGQUERYEVENTTYPE._serialized_start = 10549
-    _STREAMINGQUERYEVENTTYPE._serialized_end = 10682
+    _STREAMINGQUERYEVENTTYPE._serialized_start = 10518
+    _STREAMINGQUERYEVENTTYPE._serialized_end = 10651
     _COMMAND._serialized_start = 167
     _COMMAND._serialized_end = 1750
     _SQLCOMMAND._serialized_start = 1753
@@ -167,6 +167,6 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _CREATERESOURCEPROFILECOMMANDRESULT._serialized_end = 10295
     _REMOVECACHEDREMOTERELATIONCOMMAND._serialized_start = 10297
     _REMOVECACHEDREMOTERELATIONCOMMAND._serialized_end = 10397
-    _CHECKPOINTCOMMAND._serialized_start = 10400
-    _CHECKPOINTCOMMAND._serialized_end = 10546
+    _CHECKPOINTCOMMAND._serialized_start = 10399
+    _CHECKPOINTCOMMAND._serialized_end = 10515
 # @@protoc_insertion_point(module_scope)

diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 61691abbdd85..04d50d5b5e4f 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -2174,55 +2174,26 @@ class CheckpointCommand(google.protobuf.message.Message):
     def relation(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
         """(Required) The logical plan to checkpoint."""
     local: builtins.bool
-    """(Optional) Locally checkpoint using a local temporary
+    """(Required) Locally checkpoint using a local temporary
     directory in Spark Connect server (Spark Driver)
     """
     eager: builtins.bool
-    """(Optional) Whether to checkpoint this dataframe immediately."""
+    """(Required) Whether to checkpoint this dataframe immediately."""
     def __init__(
         self,
         *,
         relation: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
-        local: builtins.bool | None = ...,
-        eager: builtins.bool | None = ...,
+        local: builtins.bool = ...,
+        eager: builtins.bool = ...,
     ) -> None: ...
     def HasField(
-        self,
-        field_name: typing_extensions.Literal[
-            "_eager",
-            b"_eager",
-            "_local",
-            b"_local",
-            "eager",
-            b"eager",
-            "local",
-            b"local",
-            "relation",
-            b"relation",
-        ],
+        self, field_name: typing_extensions.Literal["relation", b"relation"]
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "_eager",
-            b"_eager",
-            "_local",
-            b"_local",
-            "eager",
-            b"eager",
-            "local",
-            b"local",
-            "relation",
-            b"relation",
+            "eager", b"eager", "local", b"local", "relation", b"relation"
         ],
     ) -> None: ...
-    @typing.overload
-    def WhichOneof(
-        self, oneof_group: typing_extensions.Literal["_eager", b"_eager"]
-    ) -> typing_extensions.Literal["eager"] | None: ...
-    @typing.overload
-    def WhichOneof(
-        self, oneof_group: typing_extensions.Literal["_local", b"_local"]
-    ) -> typing_extensions.Literal["local"] | None: ...
 global___CheckpointCommand = CheckpointCommand
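A practical consequence visible in the regenerated stubs: presence checks on `local` and
`eager` are gone, so callers can only use `HasField` on the `relation` message field and
otherwise read the proto3 defaults. A small sketch (again assuming the `commands_pb2` module
built from this patch):

    from pyspark.sql.connect.proto import commands_pb2

    cmd = commands_pb2.CheckpointCommand(local=True, eager=True)

    print(cmd.HasField("relation"))    # False -- message fields still track presence
    try:
        cmd.HasField("local")          # plain proto3 scalar: no presence tracking
    except ValueError as err:
        print("no presence for 'local':", err)

    cmd.ClearField("local")
    print(cmd.local)                   # False -- reset to the proto3 default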
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3e843e64ebbf..c7511737b2b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -754,7 +754,7 @@ class Dataset[T] private[sql](
    *   checkpoint directory. If false creates a local checkpoint using
    *   the caching subsystem
    */
-  private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
+  private[sql] def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
     val actionName = if (reliableCheckpoint) "checkpoint" else "localCheckpoint"
     withAction(actionName, queryExecution) { physicalPlan =>
       val internalRdd = physicalPlan.execute().map(_.copy())

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org