This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e8f58a9c4a64 [SPARK-48370][SPARK-48258][CONNECT][PYTHON][FOLLOW-UP] Refactor local and eager required fields in CheckpointCommand
e8f58a9c4a64 is described below

commit e8f58a9c4a641b830c5304b34b876e0cd5d3ed8e
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu May 23 13:50:34 2024 +0900

    [SPARK-48370][SPARK-48258][CONNECT][PYTHON][FOLLOW-UP] Refactor local and eager required fields in CheckpointCommand
    
    ### What changes were proposed in this pull request?
    
    This PR is a follow-up of https://github.com/apache/spark/pull/46683 and https://github.com/apache/spark/pull/46570 that refactors the `local` and `eager` fields in `CheckpointCommand` to be required.
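
    With both fields now required, the Spark Connect planner no longer needs to branch on field presence: it always reads `getLocal`/`getEager` and forwards them to a single `Dataset.checkpoint(eager, reliableCheckpoint)` call, as the planner diff below shows. A rough, self-contained Scala sketch of that dispatch follows; the types here are illustrative stand-ins, not the real Spark Connect classes.

        // Illustrative stand-ins for proto.CheckpointCommand and sql.Dataset.
        object CheckpointDispatchSketch {
          final case class CheckpointCommand(local: Boolean, eager: Boolean)

          final case class Dataset(plan: String) {
            def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset = {
              val actionName = if (reliableCheckpoint) "checkpoint" else "localCheckpoint"
              println(s"$actionName(eager = $eager) on $plan")
              this
            }
          }

          // With `local` and `eager` required, one call replaces the old
          // hasLocal/hasEager presence branching.
          def handleCheckpointCommand(target: Dataset, cmd: CheckpointCommand): Dataset =
            target.checkpoint(eager = cmd.eager, reliableCheckpoint = !cmd.local)

          def main(args: Array[String]): Unit = {
            handleCheckpointCommand(Dataset("df"), CheckpointCommand(local = true, eager = true))
            handleCheckpointCommand(Dataset("df"), CheckpointCommand(local = false, eager = false))
          }
        }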
    
    ### Why are the changes needed?
    
    To make the code easier to maintain.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, the main change has not been released yet.
    
    ### How was this patch tested?
    
    Manually tested.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46712 from HyukjinKwon/SPARK-48370-SPARK-48258-followup.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  2 +-
 .../src/main/protobuf/spark/connect/commands.proto |  8 ++---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 12 ++-----
 python/pyspark/sql/connect/dataframe.py            |  2 +-
 python/pyspark/sql/connect/proto/commands_pb2.py   | 10 +++---
 python/pyspark/sql/connect/proto/commands_pb2.pyi  | 41 ++++------------------
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  2 +-
 7 files changed, 21 insertions(+), 56 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index fc9766357cb2..5ac07270b22b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3481,7 +3481,7 @@ class Dataset[T] private[sql] (
     sparkSession.newDataset(agnosticEncoder) { builder =>
       val command = sparkSession.newCommand { builder =>
         builder.getCheckpointCommandBuilder
-          .setLocal(reliableCheckpoint)
+          .setLocal(!reliableCheckpoint)
           .setEager(eager)
           .setRelation(this.plan.getRoot)
       }
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index c526f8d3f65d..0e0c55fa34f0 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -497,10 +497,10 @@ message CheckpointCommand {
   // (Required) The logical plan to checkpoint.
   Relation relation = 1;
 
-  // (Optional) Locally checkpoint using a local temporary
+  // (Required) Locally checkpoint using a local temporary
   // directory in Spark Connect server (Spark Driver)
-  optional bool local = 2;
+  bool local = 2;
 
-  // (Optional) Whether to checkpoint this dataframe immediately.
-  optional bool eager = 3;
+  // (Required) Whether to checkpoint this dataframe immediately.
+  bool eager = 3;
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cbc60d2873f9..a339469e61cd 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -3523,15 +3523,9 @@ class SparkConnectPlanner(
       responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
     val target = Dataset
       .ofRows(session, transformRelation(checkpointCommand.getRelation))
-    val checkpointed = if (checkpointCommand.hasLocal && checkpointCommand.hasEager) {
-      target.localCheckpoint(eager = checkpointCommand.getEager)
-    } else if (checkpointCommand.hasLocal) {
-      target.localCheckpoint()
-    } else if (checkpointCommand.hasEager) {
-      target.checkpoint(eager = checkpointCommand.getEager)
-    } else {
-      target.checkpoint()
-    }
+    val checkpointed = target.checkpoint(
+      eager = checkpointCommand.getEager,
+      reliableCheckpoint = !checkpointCommand.getLocal)
 
     val dfId = UUID.randomUUID().toString
     logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 510776bb752d..62c73da374bc 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -2096,7 +2096,7 @@ class DataFrame(ParentDataFrame):
         return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)
 
     def checkpoint(self, eager: bool = True) -> "DataFrame":
-        cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
+        cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager)
         _, properties = self._session.client.execute_command(cmd.command(self._session.client))
         assert "checkpoint_command_result" in properties
         checkpointed = properties["checkpoint_command_result"]
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index 43673d9707a9..8f67f817c3f0 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xaf\x0c\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...]
+    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xaf\x0c\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -71,8 +71,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_options = b"8\001"
     _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._options = None
     _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_options = b"8\001"
-    _STREAMINGQUERYEVENTTYPE._serialized_start = 10549
-    _STREAMINGQUERYEVENTTYPE._serialized_end = 10682
+    _STREAMINGQUERYEVENTTYPE._serialized_start = 10518
+    _STREAMINGQUERYEVENTTYPE._serialized_end = 10651
     _COMMAND._serialized_start = 167
     _COMMAND._serialized_end = 1750
     _SQLCOMMAND._serialized_start = 1753
@@ -167,6 +167,6 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _CREATERESOURCEPROFILECOMMANDRESULT._serialized_end = 10295
     _REMOVECACHEDREMOTERELATIONCOMMAND._serialized_start = 10297
     _REMOVECACHEDREMOTERELATIONCOMMAND._serialized_end = 10397
-    _CHECKPOINTCOMMAND._serialized_start = 10400
-    _CHECKPOINTCOMMAND._serialized_end = 10546
+    _CHECKPOINTCOMMAND._serialized_start = 10399
+    _CHECKPOINTCOMMAND._serialized_end = 10515
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 61691abbdd85..04d50d5b5e4f 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -2174,55 +2174,26 @@ class CheckpointCommand(google.protobuf.message.Message):
     def relation(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
         """(Required) The logical plan to checkpoint."""
     local: builtins.bool
-    """(Optional) Locally checkpoint using a local temporary
+    """(Required) Locally checkpoint using a local temporary
     directory in Spark Connect server (Spark Driver)
     """
     eager: builtins.bool
-    """(Optional) Whether to checkpoint this dataframe immediately."""
+    """(Required) Whether to checkpoint this dataframe immediately."""
     def __init__(
         self,
         *,
         relation: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
-        local: builtins.bool | None = ...,
-        eager: builtins.bool | None = ...,
+        local: builtins.bool = ...,
+        eager: builtins.bool = ...,
     ) -> None: ...
     def HasField(
-        self,
-        field_name: typing_extensions.Literal[
-            "_eager",
-            b"_eager",
-            "_local",
-            b"_local",
-            "eager",
-            b"eager",
-            "local",
-            b"local",
-            "relation",
-            b"relation",
-        ],
+        self, field_name: typing_extensions.Literal["relation", b"relation"]
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "_eager",
-            b"_eager",
-            "_local",
-            b"_local",
-            "eager",
-            b"eager",
-            "local",
-            b"local",
-            "relation",
-            b"relation",
+            "eager", b"eager", "local", b"local", "relation", b"relation"
         ],
     ) -> None: ...
-    @typing.overload
-    def WhichOneof(
-        self, oneof_group: typing_extensions.Literal["_eager", b"_eager"]
-    ) -> typing_extensions.Literal["eager"] | None: ...
-    @typing.overload
-    def WhichOneof(
-        self, oneof_group: typing_extensions.Literal["_local", b"_local"]
-    ) -> typing_extensions.Literal["local"] | None: ...
 
 global___CheckpointCommand = CheckpointCommand
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3e843e64ebbf..c7511737b2b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -754,7 +754,7 @@ class Dataset[T] private[sql](
   *                           checkpoint directory. If false creates a local checkpoint using
    *                           the caching subsystem
    */
-  private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
+  private[sql] def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
     val actionName = if (reliableCheckpoint) "checkpoint" else "localCheckpoint"
     withAction(actionName, queryExecution) { physicalPlan =>
       val internalRdd = physicalPlan.execute().map(_.copy())


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
