This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 05e1706e5aa6 [SPARK-48310][PYTHON][CONNECT] Cached properties must return copies
05e1706e5aa6 is described below

commit 05e1706e5aa66a592e61b03263683a2dbbc64afe
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Fri May 17 10:28:36 2024 +0900

    [SPARK-48310][PYTHON][CONNECT] Cached properties must return copies
    
    ### What changes were proposed in this pull request?
    When a consumer modifies the value returned by a cached property, it also modifies the cached value itself, because the property returns a reference to the cached object rather than a copy.
    
    Before:
    ```python
    df_columns = df.columns
    for col in ['id', 'name']:
      df_columns.remove(col)
    assert len(df_columns) == len(df.columns)  # both lists shrank: the cache was mutated
    ```
    
    This is wrong: mutating the returned list also corrupted the cached property. With this patch, the property returns a copy, so the cached value is unaffected:
    
    ```python
    df_columns = df.columns
    for col in ['id', 'name']:
      df_columns.remove(col)
    assert len(df_columns) != len(df.columns)  # df.columns still has all fields
    ```
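    
    For illustration only, a minimal, self-contained sketch of the same pattern (a cached property that hands out copies so callers cannot corrupt the cache); the `Example` class and column names below are made up and are not the actual Spark code:
    
    ```python
    import copy
    from typing import List, Optional
    
    
    class Example:
        def __init__(self) -> None:
            self._cached_columns: Optional[List[str]] = None
    
        @property
        def columns(self) -> List[str]:
            # Compute once, then return copies so callers cannot mutate the cache.
            if self._cached_columns is None:
                self._cached_columns = ["id", "name", "age", "city"]
            return copy.deepcopy(self._cached_columns)
    
    
    ex = Example()
    cols = ex.columns
    for col in ["id", "name"]:
        cols.remove(col)
    assert len(cols) == 2
    assert len(ex.columns) == 4  # the cached value is untouched
    ```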
    
    ### Why are the changes needed?
    Correctness of the API
    
    ### Does this PR introduce _any_ user-facing change?
    No, this makes the code consistent with Spark classic.
    
    ### How was this patch tested?
    Unit test (`test_cached_property_is_copied` in `test_parity_dataframe.py`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46621 from grundprinzip/grundprinzip/SPARK-48310.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py            |  3 ++-
 .../sql/tests/connect/test_parity_dataframe.py     | 24 ++++++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index ccaaa15f3190..05300909cdce 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -43,6 +43,7 @@ from typing import (
     Type,
 )
 
+import copy
 import sys
 import random
 import pyarrow as pa
@@ -1787,7 +1788,7 @@ class DataFrame(ParentDataFrame):
         if self._cached_schema is None:
             query = self._plan.to_proto(self._session.client)
             self._cached_schema = self._session.client.schema(query)
-        return self._cached_schema
+        return copy.deepcopy(self._cached_schema)
 
     def isLocal(self) -> bool:
         query = self._plan.to_proto(self._session.client)
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 343f485553a9..c9888a6a8f1a 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -19,6 +19,7 @@ import unittest
 
 from pyspark.sql.tests.test_dataframe import DataFrameTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.types import StructType, StructField, IntegerType, StringType
 
 
 class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
@@ -26,6 +27,29 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
         df = self.spark.createDataFrame(data=[{"foo": "bar"}, {"foo": "baz"}])
         super().check_help_command(df)
 
+    def test_cached_property_is_copied(self):
+        schema = StructType(
+            [
+                StructField("id", IntegerType(), True),
+                StructField("name", StringType(), True),
+                StructField("age", IntegerType(), True),
+                StructField("city", StringType(), True),
+            ]
+        )
+        # Create some dummy data
+        data = [
+            (1, "Alice", 30, "New York"),
+            (2, "Bob", 25, "San Francisco"),
+            (3, "Cathy", 29, "Los Angeles"),
+            (4, "David", 35, "Chicago"),
+        ]
+        df = self.spark.createDataFrame(data, schema)
+        df_columns = df.columns
+        assert len(df.columns) == 4
+        for col in ["id", "name"]:
+            df_columns.remove(col)
+        assert len(df.columns) == 4
+
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_toDF_with_schema_string(self):
         super().test_toDF_with_schema_string()

