This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 06aafb1eeaf3 [SPARK-48258][PYTHON][CONNECT][FOLLOW-UP] Bind relation ID to the plan instead of DataFrame
06aafb1eeaf3 is described below

commit 06aafb1eeaf32bdc7abce5bb4a9ffb474a9e61ae
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu May 23 08:52:36 2024 +0900

    [SPARK-48258][PYTHON][CONNECT][FOLLOW-UP] Bind relation ID to the plan instead of DataFrame
    
    ### What changes were proposed in this pull request?
    
    This PR addresses the comment at https://github.com/apache/spark/pull/46683#discussion_r1608527529
    on the Python side, by binding the relation ID to the plan instead of the DataFrame itself.
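    
    For illustration only, here is a minimal sketch of the idea (simplified, hypothetical stand-ins rather
    than the real Spark Connect classes): the relation ID and the server-side cleanup hook now live on the
    `CachedRemoteRelation` plan node, and a `DataFrame` merely points at its plan.
    
    ```python
    # Hypothetical stand-ins mirroring where the state lives after this change.
    class CachedRemoteRelation:
        """Plan node that owns the server-side cache entry."""

        def __init__(self, relation_id: str, spark_session) -> None:
            self._relation_id = relation_id
            # The session is kept so the node itself can ask the server to
            # drop the cached relation once the node is garbage-collected.
            self._spark_session = spark_session

        def __del__(self) -> None:
            # The real code sends a RemoveRemoteCachedRelation command here.
            print(f"would release cached relation {self._relation_id}")


    class DataFrame:
        """The DataFrame no longer holds the relation ID or a __del__ hook."""

        def __init__(self, plan, session) -> None:
            self._plan = plan
            self._session = session
    ```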
    
    ### Why are the changes needed?
    
    Because the DataFrame holds the relation ID, if DataFrame B is derived from DataFrame A
    and DataFrame A is garbage-collected, the server-side cache entry might no longer exist.
    See the example below:
    
    ```python
    df = spark.range(1).localCheckpoint()
    df2 = df.repartition(10)
    del df
    df2.collect()
    ```
    
    ```
    pyspark.errors.exceptions.connect.SparkConnectGrpcException: (org.apache.spark.sql.connect.common.InvalidPlanInput) No DataFrame with id a4efa660-897c-4500-bd4e-bd57cd0263d2 is found in the session cd4764b4-90a9-4249-9140-12a6e4a98cd3
    ```
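    
    With the ID bound to the plan, the snippet above is expected to succeed: `df2`'s plan tree still
    references the `CachedRemoteRelation` node, so the cached entry is released only when the last
    referencing plan is garbage-collected. Below is a self-contained sketch of that reference-counting
    behaviour, using hypothetical stand-in classes rather than the real API:
    
    ```python
    import gc


    class CachedRelationStandIn:
        """Hypothetical stand-in for the CachedRemoteRelation plan node."""

        def __init__(self, relation_id):
            self._relation_id = relation_id

        def __del__(self):
            print(f"released {self._relation_id}")


    class DerivedPlanStandIn:
        """Any derived plan (e.g. repartition) keeps a reference to its child node."""

        def __init__(self, child):
            self._child = child


    root = CachedRelationStandIn("example-id")
    derived = DerivedPlanStandIn(root)

    del root
    gc.collect()  # nothing is released; `derived` still references the node

    del derived
    gc.collect()  # only now is "released example-id" printed
    ```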
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, the main change has not been released yet.
    
    ### How was this patch tested?
    
    Manually tested, and added a unit test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46694 from HyukjinKwon/SPARK-48258-followup.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/conversion.py           |  5 +-
 python/pyspark/sql/connect/dataframe.py            | 38 ---------------
 python/pyspark/sql/connect/plan.py                 | 54 ++++++++++++++++++----
 python/pyspark/sql/connect/session.py              |  2 +-
 .../sql/tests/connect/test_connect_basic.py        | 53 ++++++++++++++++++---
 5 files changed, 97 insertions(+), 55 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index b1cf88e40a4e..1c205586d609 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -577,7 +577,8 @@ def proto_to_remote_cached_dataframe(relation: pb2.CachedRemoteRelation) -> "Dat
     from pyspark.sql.connect.session import SparkSession
     import pyspark.sql.connect.plan as plan
 
+    session = SparkSession.active()
     return DataFrame(
-        plan=plan.CachedRemoteRelation(relation.relation_id),
-        session=SparkSession.active(),
+        plan=plan.CachedRemoteRelation(relation.relation_id, session),
+        session=session,
     )
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 3725bc3ba0e4..510776bb752d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -16,7 +16,6 @@
 #
 
 # mypy: disable-error-code="override"
-from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
 from pyspark.errors.exceptions.base import (
     SessionNotSameException,
     PySparkIndexError,
@@ -138,41 +137,6 @@ class DataFrame(ParentDataFrame):
         # by __repr__ and _repr_html_ while eager evaluation opens.
         self._support_repr_html = False
         self._cached_schema: Optional[StructType] = None
-        self._cached_remote_relation_id: Optional[str] = None
-
-    def __del__(self) -> None:
-        # If session is already closed, all cached DataFrame should be released.
-        if not self._session.client.is_closed and self._cached_remote_relation_id is not None:
-            try:
-                command = plan.RemoveRemoteCachedRelation(
-                    plan.CachedRemoteRelation(relationId=self._cached_remote_relation_id)
-                ).command(session=self._session.client)
-                req = self._session.client._execute_plan_request_with_metadata()
-                if self._session.client._user_id:
-                    req.user_context.user_id = self._session.client._user_id
-                req.plan.command.CopyFrom(command)
-
-                for attempt in self._session.client._retrying():
-                    with attempt:
-                        # !!HACK ALERT!!
-                        # unary_stream does not work on Python's exit for an unknown reasons
-                        # Therefore, here we open unary_unary channel instead.
-                        # See also :class:`SparkConnectServiceStub`.
-                        request_serializer = (
-                            spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
-                        )
-                        response_deserializer = (
-                            spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
-                        )
-                        channel = self._session.client._channel.unary_unary(
-                            "/spark.connect.SparkConnectService/ExecutePlan",
-                            request_serializer=request_serializer,
-                            response_deserializer=response_deserializer,
-                        )
-                        metadata = self._session.client._builder.metadata()
-                        channel(req, metadata=metadata)  # type: ignore[arg-type]
-            except Exception as e:
-                warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")
 
     def __reduce__(self) -> Tuple:
         """
@@ -2137,7 +2101,6 @@ class DataFrame(ParentDataFrame):
         assert "checkpoint_command_result" in properties
         checkpointed = properties["checkpoint_command_result"]
         assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
-        checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
         return checkpointed
 
     def localCheckpoint(self, eager: bool = True) -> "DataFrame":
@@ -2146,7 +2109,6 @@ class DataFrame(ParentDataFrame):
         assert "checkpoint_command_result" in properties
         checkpointed = properties["checkpoint_command_result"]
         assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
-        checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
         return checkpointed
 
     if not is_remote_only():
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 94c2641bb4d2..868bd4fb57aa 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -40,6 +40,7 @@ import json
 import pickle
 from threading import Lock
 from inspect import signature, isclass
+import warnings
 
 import pyarrow as pa
 
@@ -49,6 +50,7 @@ from pyspark.sql.types import DataType
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.column import Column
+from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
 from pyspark.sql.connect.conversion import storage_level_to_proto
 from pyspark.sql.connect.expressions import Expression
 from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
@@ -62,6 +64,7 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
     from pyspark.sql.connect.udf import UserDefinedFunction
     from pyspark.sql.connect.observation import Observation
+    from pyspark.sql.connect.session import SparkSession
 
 
 class LogicalPlan:
@@ -547,14 +550,49 @@ class CachedRemoteRelation(LogicalPlan):
     """Logical plan object for a DataFrame reference which represents a 
DataFrame that's been
     cached on the server with a given id."""
 
-    def __init__(self, relationId: str):
+    def __init__(self, relation_id: str, spark_session: "SparkSession"):
         super().__init__(None)
-        self._relationId = relationId
-
-    def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = self._create_proto_relation()
-        plan.cached_remote_relation.relation_id = self._relationId
-        return plan
+        self._relation_id = relation_id
+        # Needs to hold the session to make a request itself.
+        self._spark_session = spark_session
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = self._create_proto_relation()
+        plan.cached_remote_relation.relation_id = self._relation_id
+        return plan
+
+    def __del__(self) -> None:
+        session = self._spark_session
+        # If session is already closed, all cached DataFrame should be released.
+        if session is not None and not session.client.is_closed and self._relation_id is not None:
+            try:
+                command = RemoveRemoteCachedRelation(self).command(session=session.client)
+                req = session.client._execute_plan_request_with_metadata()
+                if session.client._user_id:
+                    req.user_context.user_id = session.client._user_id
+                req.plan.command.CopyFrom(command)
+
+                for attempt in session.client._retrying():
+                    with attempt:
+                        # !!HACK ALERT!!
+                        # unary_stream does not work on Python's exit for an unknown reasons
+                        # Therefore, here we open unary_unary channel instead.
+                        # See also :class:`SparkConnectServiceStub`.
+                        request_serializer = (
+                            spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
+                        )
+                        response_deserializer = (
+                            spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
+                        )
+                        )
+                        channel = session.client._channel.unary_unary(
+                            "/spark.connect.SparkConnectService/ExecutePlan",
+                            request_serializer=request_serializer,
+                            response_deserializer=response_deserializer,
+                        )
+                        metadata = session.client._builder.metadata()
+                        channel(req, metadata=metadata)  # type: ignore[arg-type]
+            except Exception as e:
+                warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")
 
 
 class Hint(LogicalPlan):
@@ -1792,7 +1830,7 @@ class RemoveRemoteCachedRelation(LogicalPlan):
 
     def command(self, session: "SparkConnectClient") -> proto.Command:
         plan = self._create_proto_relation()
-        plan.cached_remote_relation.relation_id = self._relation._relationId
+        plan.cached_remote_relation.relation_id = self._relation._relation_id
         cmd = proto.Command()
         cmd.remove_cached_remote_relation_command.relation.CopyFrom(plan.cached_remote_relation)
         return cmd
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 5e6c5e558764..f99d298ea117 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -926,7 +926,7 @@ class SparkSession:
         This is used in ForeachBatch() runner, where the remote DataFrame refers to the
         output of a micro batch.
         """
-        return DataFrame(CachedRemoteRelation(remote_id), self)
+        return DataFrame(CachedRemoteRelation(remote_id, spark_session=self), self)
 
     @staticmethod
     def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index b144c3b8de20..0648b5ce9925 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -16,10 +16,10 @@
 #
 
 import os
+import gc
 import unittest
 import shutil
 import tempfile
-import time
 
 from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
@@ -34,6 +34,7 @@ from pyspark.sql.types import (
     ArrayType,
     Row,
 )
+from pyspark.testing.utils import eventually
 from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.testing.connectutils import (
     should_test_connect,
@@ -1379,8 +1380,8 @@ class SparkConnectGCTests(SparkConnectSQLTestCase):
         # SPARK-48258: Make sure garbage-collecting DataFrame remove the paired state
         # in Spark Connect server
         df = self.connect.range(10).localCheckpoint()
-        self.assertIsNotNone(df._cached_remote_relation_id)
-        cached_remote_relation_id = df._cached_remote_relation_id
+        self.assertIsNotNone(df._plan._relation_id)
+        cached_remote_relation_id = df._plan._relation_id
 
         jvm = self.spark._jvm
         session_holder = getattr(
@@ -1397,14 +1398,54 @@ class SparkConnectGCTests(SparkConnectSQLTestCase):
         )
 
         del df
+        gc.collect()
 
-        time.sleep(3)  # Make sure removing is triggered, and executed in the server.
+        def condition():
+            # Check the state was removed up on garbage-collection.
+            self.assertIsNone(
+                session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
+            )
+
+        eventually(catch_assertions=True)(condition)()
+
+    def test_garbage_collection_derived_checkpoint(self):
+        # SPARK-48258: Should keep the cached remote relation when derived DataFrames exist
+        df = self.connect.range(10).localCheckpoint()
+        self.assertIsNotNone(df._plan._relation_id)
+        derived = df.repartition(10)
+        cached_remote_relation_id = df._plan._relation_id
 
-        # Check the state was removed up on garbage-collection.
-        self.assertIsNone(
+        jvm = self.spark._jvm
+        session_holder = getattr(
+            getattr(
+                jvm.org.apache.spark.sql.connect.service,
+                "SparkConnectService$",
+            ),
+            "MODULE$",
+        ).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id)
+
+        # Check the state exists.
+        self.assertIsNotNone(
             session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
         )
 
+        del df
+        gc.collect()
+
+        def condition():
+            self.assertIsNone(
+                session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
+            )
+
+        # Should not remove the cache
+        with self.assertRaises(AssertionError):
+            eventually(catch_assertions=True, timeout=5)(condition)()
+
+        del derived
+        gc.collect()
+
+        eventually(catch_assertions=True)(condition)()
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
