This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 43ac3db1e27 [SPARK-44528][CONNECT] Support proper usage of hasattr() for Connect dataframe
43ac3db1e27 is described below

commit 43ac3db1e27a4169183a90b54b6a873f0d26a7ba
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Thu Jul 27 08:53:45 2023 +0900

    [SPARK-44528][CONNECT] Support proper usage of hasattr() for Connect dataframe

    ### What changes were proposed in this pull request?
    Currently Connect does not allow the proper usage of Python's `hasattr()` to identify if an attribute is defined or not. This patch fixes that bug (it's working in regular PySpark).

    ### Why are the changes needed?
    Bugfix

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    UT

    Closes #42132 from grundprinzip/SPARK-44528.

    Lead-authored-by: Martin Grund <martin.gr...@databricks.com>
    Co-authored-by: Martin Grund <grundprin...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 91e97f92fe76f9718cd16af0c761d5530bdb37ee)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py | 8 +++++++
 .../sql/tests/connect/test_connect_basic.py | 17 +++++++++++--
 python/pyspark/testing/connectutils.py | 28 ++++++++++++++++++----
 3 files changed, 46 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 6429645f0e0..12e424b5ef1 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1584,8 +1584,16 @@ class DataFrame:
                 error_class="NOT_IMPLEMENTED",
                 message_parameters={"feature": f"{name}()"},
             )
+
+        if name not in self.columns:
+            raise AttributeError(
+                "'%s' object has no attribute '%s'" % (self.__class__.__name__, name)
+            )
+
         return self[name]

+    __getattr__.__doc__ = PySparkDataFrame.__getattr__.__doc__
+
     @overload
     def __getitem__(self, item: Union[int, str]) -> Column:
         ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 5259ea6b5f5..065f1585a9f 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -157,6 +157,19 @@ class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSpark


 class SparkConnectBasicTests(SparkConnectSQLTestCase):
+    def test_df_getattr_behavior(self):
+        cdf = self.connect.range(10)
+        sdf = self.spark.range(10)
+
+        sdf._simple_extension = 10
+        cdf._simple_extension = 10
+
+        self.assertEqual(sdf._simple_extension, cdf._simple_extension)
+        self.assertEqual(type(sdf._simple_extension), type(cdf._simple_extension))
+
+        self.assertTrue(hasattr(cdf, "_simple_extension"))
+        self.assertFalse(hasattr(cdf, "_simple_extension_does_not_exsit"))
+
     def test_df_get_item(self):
         # SPARK-41779: test __getitem__

@@ -1296,8 +1309,8 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             sdf.drop("a", "x").toPandas(),
         )
         self.assert_eq(
-            cdf.drop(cdf.a, cdf.x).toPandas(),
-            sdf.drop("a", "x").toPandas(),
+            cdf.drop(cdf.a, "x").toPandas(),
+            sdf.drop(sdf.a, "x").toPandas(),
         )

     def test_subquery_alias(self) -> None:
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 1b3ac10fce8..b6145d0a006 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -16,6 +16,7 @@
 #
 import shutil
 import tempfile
+import types
 import typing
 import os
 import functools
@@ -67,7 +68,7 @@ should_test_connect: str = typing.cast(str, connect_requirement_message is None)

 if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame
-    from pyspark.sql.connect.plan import Read, Range, SQL
+    from pyspark.sql.connect.plan import Read, Range, SQL, LogicalPlan
     from pyspark.sql.connect.session import SparkSession


@@ -88,16 +89,33 @@ class MockRemoteSession:
         return functools.partial(self.hooks[item])


+class MockDF(DataFrame):
+    """Helper class that must only be used for the mock plan tests."""
+
+    def __init__(self, session: SparkSession, plan: LogicalPlan):
+        super().__init__(session)
+        self._plan = plan
+
+    def __getattr__(self, name):
+        """All attributes are resolved to columns, because none really exist in the
+        mocked DataFrame."""
+        return self[name]
+
+
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
     @classmethod
     def _read_table(cls, table_name):
-        return DataFrame.withPlan(Read(table_name), cls.connect)
+        return cls._df_mock(Read(table_name))

     @classmethod
     def _udf_mock(cls, *args, **kwargs):
         return "internal_name"

+    @classmethod
+    def _df_mock(cls, plan: LogicalPlan) -> MockDF:
+        return MockDF(cls.connect, plan)
+
     @classmethod
     def _session_range(
         cls,
@@ -106,17 +124,17 @@ class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
         step=1,
         num_partitions=None,
     ):
-        return DataFrame.withPlan(Range(start, end, step, num_partitions), cls.connect)
+        return cls._df_mock(Range(start, end, step, num_partitions))

     @classmethod
     def _session_sql(cls, query):
-        return DataFrame.withPlan(SQL(query), cls.connect)
+        return cls._df_mock(SQL(query))

     if have_pandas:

         @classmethod
         def _with_plan(cls, plan):
-            return DataFrame.withPlan(plan, cls.connect)
+            return cls._df_mock(plan)

     @classmethod
     def setUpClass(cls):

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org