This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 764cb3e07364 [SPARK-46750][CONNECT][PYTHON] DataFrame APIs code clean up
764cb3e07364 is described below

commit 764cb3e073644b9d543502d8951c47e41ba0f46b
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Jan 18 08:17:29 2024 +0800

    [SPARK-46750][CONNECT][PYTHON] DataFrame APIs code clean up
    
    ### What changes were proposed in this pull request?
    1, unify the imports;
    2, delete unused helper functions and variables;
    
    ### Why are the changes needed?
    code clean up
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44771 from zhengruifeng/py_df_cleanup.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py | 38 ++++++++++-----------------------
 python/pyspark/sql/connect/group.py     | 15 ++++---------
 2 files changed, 15 insertions(+), 38 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 7ee27065208c..0cf6c0921f78 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -72,19 +72,12 @@ from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
+    SortOrder,
     ColumnReference,
     UnresolvedRegex,
     UnresolvedStar,
 )
-from pyspark.sql.connect.functions.builtin import (
-    _to_col,
-    _invoke_function,
-    col,
-    lit,
-    udf,
-    struct,
-    expr as sql_expression,
-)
+from pyspark.sql.connect.functions import builtin as F
 from pyspark.sql.pandas.types import from_arrow_schema
 
 
@@ -199,9 +192,9 @@ class DataFrame:
             expr = expr[0]  # type: ignore[assignment]
         for element in expr:
             if isinstance(element, str):
-                sql_expr.append(sql_expression(element))
+                sql_expr.append(F.expr(element))
             else:
-                sql_expr.extend([sql_expression(e) for e in element])
+                sql_expr.extend([F.expr(e) for e in element])
 
         return DataFrame(plan.Project(self._plan, *sql_expr), session=self._session)
 
@@ -215,7 +208,7 @@ class DataFrame:
             )
 
         if len(exprs) == 1 and isinstance(exprs[0], dict):
-            measures = [_invoke_function(f, col(e)) for e, f in exprs[0].items()]
+            measures = [F._invoke_function(f, F.col(e)) for e, f in exprs[0].items()]
             return self.groupBy().agg(*measures)
         else:
             # other expressions
@@ -259,7 +252,7 @@ class DataFrame:
     sparkSession.__doc__ = PySparkDataFrame.sparkSession.__doc__
 
     def count(self) -> int:
-        table, _ = self.agg(_invoke_function("count", lit(1)))._to_table()
+        table, _ = self.agg(F._invoke_function("count", F.lit(1)))._to_table()
         return table[0][0].as_py()
 
     count.__doc__ = PySparkDataFrame.count.__doc__
@@ -352,8 +345,6 @@ class DataFrame:
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> "DataFrame":
         def _convert_col(col: "ColumnOrName") -> "ColumnOrName":
-            from pyspark.sql.connect.expressions import SortOrder, ColumnReference
-
             if isinstance(col, Column):
                 if isinstance(col._expr, SortOrder):
                     return col
@@ -471,7 +462,7 @@ class DataFrame:
 
     def filter(self, condition: Union[Column, str]) -> "DataFrame":
         if isinstance(condition, str):
-            expr = sql_expression(condition)
+            expr = F.expr(condition)
         else:
             expr = condition
         return DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
@@ -713,7 +704,7 @@ class DataFrame:
                     )
             else:
                 _c = c  # type: ignore[assignment]
-            _cols.append(_to_col(cast("ColumnOrName", _c)))
+            _cols.append(F._to_col(cast("ColumnOrName", _c)))
 
         ascending = kwargs.get("ascending", True)
         if isinstance(ascending, (bool, int)):
@@ -1652,8 +1643,6 @@ class DataFrame:
     def sampleBy(
         self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
     ) -> "DataFrame":
-        from pyspark.sql.connect.expressions import ColumnReference
-
         if isinstance(col, str):
             col = Column(ColumnReference(col))
         elif not isinstance(col, Column):
@@ -1754,7 +1743,7 @@ class DataFrame:
         elif isinstance(item, (list, tuple)):
             return self.select(*item)
         elif isinstance(item, int):
-            return col(self.columns[item])
+            return F.col(self.columns[item])
         else:
             raise PySparkTypeError(
                 error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
@@ -1768,11 +1757,6 @@ class DataFrame:
 
     __dir__.__doc__ = PySparkDataFrame.__dir__.__doc__
 
-    def _print_plan(self) -> str:
-        if self._plan:
-            return self._plan.print()
-        return ""
-
     def collect(self) -> List[Row]:
         table, schema = self._to_table()
 
@@ -2084,8 +2068,8 @@ class DataFrame:
         def foreach_func(row: Any) -> None:
             f(row)
 
-        self.select(struct(*self.schema.fieldNames()).alias("row")).select(
-            udf(foreach_func, StructType())("row")  # type: ignore[arg-type]
+        self.select(F.struct(*self.schema.fieldNames()).alias("row")).select(
+            F.udf(foreach_func, StructType())("row")  # type: ignore[arg-type]
         ).collect()
 
     foreach.__doc__ = PySparkDataFrame.foreach.__doc__
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 2ccd7463b9e0..db4c9f57c5c2 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -40,7 +40,7 @@ from pyspark.sql.types import StructType
 
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.functions.builtin import _invoke_function, col, lit
+from pyspark.sql.connect.functions import builtin as F
 from pyspark.errors import PySparkNotImplementedError, PySparkTypeError
 
 if TYPE_CHECKING:
@@ -132,7 +132,7 @@ class GroupedData:
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
             # Convert the dict into key value pairs
-            aggregate_cols = [_invoke_function(exprs[0][k], col(k)) for k in exprs[0]]
+            aggregate_cols = [F._invoke_function(exprs[0][k], F.col(k)) for k in exprs[0]]
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
@@ -166,8 +166,6 @@ class GroupedData:
             field.name for field in schema.fields if isinstance(field.dataType, NumericType)
         ]
 
-        agg_cols: List[str] = []
-
         if len(cols) > 0:
             invalid_cols = [c for c in cols if c not in numerical_cols]
             if len(invalid_cols) > 0:
@@ -185,7 +183,7 @@ class GroupedData:
                 child=self._df._plan,
                 group_type=self._group_type,
                 grouping_cols=self._grouping_cols,
-                aggregate_cols=[_invoke_function(function, col(c)) for c in agg_cols],
+                aggregate_cols=[F._invoke_function(function, F.col(c)) for c in agg_cols],
                 pivot_col=self._pivot_col,
                 pivot_values=self._pivot_values,
                 grouping_sets=self._grouping_sets,
@@ -216,7 +214,7 @@ class GroupedData:
     mean = avg
 
     def count(self) -> "DataFrame":
-        return self.agg(_invoke_function("count", lit(1)).alias("count"))
+        return self.agg(F._invoke_function("count", F.lit(1)).alias("count"))
 
     count.__doc__ = PySparkGroupedData.count.__doc__
 
@@ -444,11 +442,6 @@ class PandasCogroupedOps:
 
     applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__
 
-    @staticmethod
-    def _extract_cols(gd: "GroupedData") -> List[Column]:
-        df = gd._df
-        return [df[col] for col in df.columns]
-
 
 PandasCogroupedOps.__doc__ = PySparkPandasCogroupedOps.__doc__
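
For context on item 1 of the commit message (unifying imports under a single module alias), here is a minimal, hypothetical sketch of the same pattern written against the public PySpark API rather than the connect-internal module touched by this diff; the "sc://localhost" endpoint and the column name "flag" are assumptions for illustration only:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F  # one namespaced import instead of importing col/lit/expr individually

    # Assumes a Spark Connect server is reachable at sc://localhost.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    df = spark.range(5).withColumn("flag", F.lit(1))            # F.lit instead of a bare lit
    out = df.filter(F.expr("id % 2 = 0")).select(F.col("id"))   # F.expr / F.col at the call sites
    out.show()

Qualifying each call site with the alias (F.col, F.lit, F.expr, ...) mirrors what the diff does with "from pyspark.sql.connect.functions import builtin as F", and avoids growing the import list every time a new function is needed.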
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
