This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new aea13fca5d57 [SPARK-47500][PYTHON][CONNECT] Factor column name handling out of `plan.py`
aea13fca5d57 is described below

commit aea13fca5d5794d03a396036f838584cc9c46986
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Mar 22 15:09:31 2024 +0900

    [SPARK-47500][PYTHON][CONNECT] Factor column name handling out of `plan.py`
    
    ### What changes were proposed in this pull request?
    Factor column name handling out of `plan.py`
    
    ### Why are the changes needed?
    there is too much parameter preprocessing in `plan.py`, e.g. column name handling;
    there are multiple duplicated helper functions here and there, which makes the code hard to follow at times.
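
    For illustration, the refactor funnels the string-to-`Column` coercion that used to be
    scattered across `plan.py` through a single helper (`F._to_col` from
    `pyspark.sql.connect.functions.builtin`, as used at the call sites below). A minimal
    sketch of such a helper, based only on the behavior visible in this diff (the real
    implementation may differ), looks like:

    ```python
    from typing import Union

    from pyspark.sql.connect.column import Column
    from pyspark.sql.connect.expressions import ColumnReference


    def _to_col(col: Union[Column, str]) -> Column:
        # Sketch only: Column instances pass through unchanged; plain strings are
        # wrapped as unresolved column references, mirroring call sites such as
        # plan.Project(self._plan, [F._to_col(c) for c in cols]).
        if isinstance(col, Column):
            return col
        return Column(ColumnReference(col))
    ```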
    
    ### Does this PR introduce _any_ user-facing change?
    no, just code refactor
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #45636 from zhengruifeng/plan_clean_up.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py  |  70 ++++++++--------
 python/pyspark/sql/connect/plan.py       | 136 +++++++------------------------
 python/pyspark/sql/connect/readwriter.py |   6 +-
 3 files changed, 71 insertions(+), 141 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 741606c89aa4..2a22d02387ae 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -182,8 +182,10 @@ class DataFrame:
     def select(self, *cols: "ColumnOrName") -> "DataFrame":
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]
-
-        return DataFrame(plan.Project(self._plan, *cols), session=self._session)
+        return DataFrame(
+            plan.Project(self._plan, [F._to_col(c) for c in cols]),
+            session=self._session,
+        )
 
     select.__doc__ = PySparkDataFrame.select.__doc__
 
@@ -197,7 +199,7 @@ class DataFrame:
             else:
                 sql_expr.extend([F.expr(e) for e in element])
 
-        return DataFrame(plan.Project(self._plan, *sql_expr), session=self._session)
+        return DataFrame(plan.Project(self._plan, sql_expr), session=self._session)
 
     selectExpr.__doc__ = PySparkDataFrame.selectExpr.__doc__
 
@@ -309,18 +311,20 @@ class DataFrame:
                 )
             if len(cols) == 0:
                 return DataFrame(
-                    plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=True),
+                    plan.Repartition(self._plan, numPartitions, shuffle=True),
                     self._session,
                 )
             else:
                 return DataFrame(
-                    plan.RepartitionByExpression(self._plan, numPartitions, list(cols)),
+                    plan.RepartitionByExpression(
+                        self._plan, numPartitions, [F._to_col(c) for c in cols]
+                    ),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
             cols = (numPartitions,) + cols
             return DataFrame(
-                plan.RepartitionByExpression(self._plan, None, list(cols)),
+                plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]),
                 self.sparkSession,
             )
         else:
@@ -345,14 +349,14 @@ class DataFrame:
     def repartitionByRange(  # type: ignore[misc]
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> "DataFrame":
-        def _convert_col(col: "ColumnOrName") -> "ColumnOrName":
+        def _convert_col(col: "ColumnOrName") -> Column:
             if isinstance(col, Column):
                 if isinstance(col._expr, SortOrder):
                     return col
                 else:
-                    return Column(SortOrder(col._expr))
+                    return col.asc()
             else:
-                return Column(SortOrder(ColumnReference(col)))
+                return F.col(col).asc()
 
         if isinstance(numPartitions, int):
             if not numPartitions > 0:
@@ -369,18 +373,17 @@ class DataFrame:
                     message_parameters={"item": "cols"},
                 )
             else:
-                sort = []
-                sort.extend([_convert_col(c) for c in cols])
                 return DataFrame(
-                    plan.RepartitionByExpression(self._plan, numPartitions, sort),
+                    plan.RepartitionByExpression(
+                        self._plan, numPartitions, [_convert_col(c) for c in cols]
+                    ),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
-            cols = (numPartitions,) + cols
-            sort = []
-            sort.extend([_convert_col(c) for c in cols])
             return DataFrame(
-                plan.RepartitionByExpression(self._plan, None, sort),
+                plan.RepartitionByExpression(
+                    self._plan, None, [_convert_col(c) for c in [numPartitions] + list(cols)]
+                ),
                 self.sparkSession,
             )
         else:
@@ -648,12 +651,18 @@ class DataFrame:
         if tolerance is not None:
             assert isinstance(tolerance, Column), "tolerance should be Column"
 
+        def _convert_col(df: "DataFrame", col: "ColumnOrName") -> Column:
+            if isinstance(col, Column):
+                return col
+            else:
+                return Column(ColumnReference(col, df._plan._plan_id))
+
         return DataFrame(
             plan.AsOfJoin(
                 left=self._plan,
                 right=other._plan,
-                left_as_of=leftAsOfColumn,
-                right_as_of=rightAsOfColumn,
+                left_as_of=_convert_col(self, leftAsOfColumn),
+                right_as_of=_convert_col(other, rightAsOfColumn),
                 on=on,
                 how=how,
                 tolerance=tolerance,
@@ -940,24 +949,21 @@ class DataFrame:
     ) -> "DataFrame":
         assert ids is not None, "ids must not be None"
 
-        def to_jcols(
+        def _convert_cols(
             cols: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]]
-        ) -> List["ColumnOrName"]:
+        ) -> List[Column]:
             if cols is None:
-                lst = []
-            elif isinstance(cols, tuple):
-                lst = list(cols)
-            elif isinstance(cols, list):
-                lst = cols
+                return []
+            elif isinstance(cols, (tuple, list)):
+                return [F._to_col(c) for c in cols]
             else:
-                lst = [cols]
-            return lst
+                return [F._to_col(cols)]
 
         return DataFrame(
             plan.Unpivot(
                 self._plan,
-                to_jcols(ids),
-                to_jcols(values) if values is not None else None,
+                _convert_cols(ids),
+                _convert_cols(values) if values is not None else None,
                 variableColumnName,
                 valueColumnName,
             ),
@@ -1645,9 +1651,7 @@ class DataFrame:
     def sampleBy(
         self, col: "ColumnOrName", fractions: Dict[Any, float], seed: 
Optional[int] = None
     ) -> "DataFrame":
-        if isinstance(col, str):
-            col = Column(ColumnReference(col))
-        elif not isinstance(col, Column):
+        if not isinstance(col, (str, Column)):
             raise PySparkTypeError(
                 error_class="NOT_COLUMN_OR_STR",
                 message_parameters={"arg_name": "col", "arg_type": 
type(col).__name__},
@@ -1671,7 +1675,7 @@ class DataFrame:
             fractions[k] = float(v)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         return DataFrame(
-            plan.StatSampleBy(child=self._plan, col=col, fractions=fractions, seed=seed),
+            plan.StatSampleBy(child=self._plan, col=F._to_col(col), fractions=fractions, seed=seed),
             session=self._session,
         )
 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 7a93580aa112..fc186b3a3df7 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -37,7 +37,6 @@ from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
     Expression,
     SortOrder,
-    ColumnReference,
     LiteralExpression,
 )
 from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
@@ -49,7 +48,6 @@ from pyspark.errors import (
 )
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect._typing import ColumnOrName
     from pyspark.sql.connect.client import SparkConnectClient
     from pyspark.sql.connect.udf import UserDefinedFunction
     from pyspark.sql.connect.observation import Observation
@@ -77,22 +75,6 @@ class LogicalPlan:
         plan.common.plan_id = self._plan_id
         return plan
 
-    def unresolved_attr(self, colName: str) -> proto.Expression:
-        """Creates an unresolved attribute from a column name."""
-        exp = proto.Expression()
-        exp.unresolved_attribute.unparsed_identifier = colName
-        return exp
-
-    def to_attr_or_expression(
-        self, col: "ColumnOrName", session: "SparkConnectClient"
-    ) -> proto.Expression:
-        """Returns either an instance of an unresolved attribute or the 
serialized
-        expression value of the column."""
-        if type(col) is str:
-            return self.unresolved_attr(col)
-        else:
-            return cast(Column, col).to_plan(session)
-
     def plan(self, session: "SparkConnectClient") -> proto.Relation:  # type: ignore[empty-body]
         ...
 
@@ -465,35 +447,20 @@ class Project(LogicalPlan):
 
     """
 
-    def __init__(self, child: Optional["LogicalPlan"], *columns: "ColumnOrName") -> None:
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        columns: List[Column],
+    ) -> None:
         super().__init__(child)
-        self._columns = list(columns)
-        self._verify_expressions()
-
-    def _verify_expressions(self) -> None:
-        """Ensures that all input arguments are instances of Expression or 
String."""
-        for c in self._columns:
-            if not isinstance(c, (Column, str)):
-                raise PySparkTypeError(
-                    error_class="NOT_LIST_OF_COLUMN_OR_STR",
-                    message_parameters={"arg_name": "columns"},
-                )
+        assert all(isinstance(c, Column) for c in columns)
+        self._columns = columns
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        from pyspark.sql.connect.functions import col
-
         assert self._child is not None
         plan = self._create_proto_relation()
         plan.project.input.CopyFrom(self._child.plan(session))
-
-        proj_exprs = []
-        for c in self._columns:
-            if isinstance(c, Column):
-                proj_exprs.append(c.to_plan(session))
-            else:
-                proj_exprs.append(col(c).to_plan(session))
-
-        plan.project.expressions.extend(proj_exprs)
+        plan.project.expressions.extend([c.to_plan(session) for c in self._columns])
         return plan
 
 
@@ -896,7 +863,7 @@ class Join(LogicalPlan):
                 if isinstance(self.on, str):
                     plan.join.using_columns.append(self.on)
                 else:
-                    plan.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
+                    plan.join.join_condition.CopyFrom(self.on.to_plan(session))
             elif len(self.on) > 0:
                 if isinstance(self.on[0], str):
                     plan.join.using_columns.extend(cast(str, self.on))
@@ -936,8 +903,8 @@ class AsOfJoin(LogicalPlan):
         self,
         left: LogicalPlan,
         right: LogicalPlan,
-        left_as_of: "ColumnOrName",
-        right_as_of: "ColumnOrName",
+        left_as_of: Column,
+        right_as_of: Column,
         on: Optional[Union[str, List[str], Column, List[Column]]],
         how: str,
         tolerance: Optional[Column],
@@ -960,19 +927,8 @@ class AsOfJoin(LogicalPlan):
         plan.as_of_join.left.CopyFrom(self.left.plan(session))
         plan.as_of_join.right.CopyFrom(self.right.plan(session))
 
-        if isinstance(self.left_as_of, Column):
-            plan.as_of_join.left_as_of.CopyFrom(self.left_as_of.to_plan(session))
-        else:
-            plan.as_of_join.left_as_of.CopyFrom(
-                ColumnReference(self.left_as_of, self.left._plan_id).to_plan(session)
-            )
-
-        if isinstance(self.right_as_of, Column):
-            plan.as_of_join.right_as_of.CopyFrom(self.right_as_of.to_plan(session))
-        else:
-            plan.as_of_join.right_as_of.CopyFrom(
-                ColumnReference(self.right_as_of, self.right._plan_id).to_plan(session)
-            )
+        plan.as_of_join.left_as_of.CopyFrom(self.left_as_of.to_plan(session))
+        plan.as_of_join.right_as_of.CopyFrom(self.right_as_of.to_plan(session))
 
         if self.on is not None:
             if not isinstance(self.on, list):
@@ -1128,26 +1084,18 @@ class RepartitionByExpression(LogicalPlan):
         self,
         child: Optional["LogicalPlan"],
         num_partitions: Optional[int],
-        columns: List["ColumnOrName"],
+        columns: List[Column],
     ) -> None:
         super().__init__(child)
         self.num_partitions = num_partitions
+        assert all(isinstance(c, Column) for c in columns)
         self.columns = columns
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         plan = self._create_proto_relation()
-
-        part_exprs = []
-        for c in self.columns:
-            if isinstance(c, Column):
-                part_exprs.append(c.to_plan(session))
-            elif c == "*":
-                exp = proto.Expression()
-                exp.unresolved_star.SetInParent()
-                part_exprs.append(exp)
-            else:
-                part_exprs.append(self.unresolved_attr(c))
-        plan.repartition_by_expression.partition_exprs.extend(part_exprs)
+        plan.repartition_by_expression.partition_exprs.extend(
+            [c.to_plan(session) for c in self.columns]
+        )
 
         if self._child is not None:
             plan.repartition_by_expression.input.CopyFrom(self._child.plan(session))
@@ -1283,8 +1231,8 @@ class Unpivot(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        ids: List["ColumnOrName"],
-        values: Optional[List["ColumnOrName"]],
+        ids: List[Column],
+        values: Optional[List[Column]],
         variable_column_name: str,
         value_column_name: str,
     ) -> None:
@@ -1294,19 +1242,13 @@ class Unpivot(LogicalPlan):
         self.variable_column_name = variable_column_name
         self.value_column_name = value_column_name
 
-    def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") 
-> proto.Expression:
-        if isinstance(col, Column):
-            return col.to_plan(session)
-        else:
-            return self.unresolved_attr(col)
-
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = self._create_proto_relation()
         plan.unpivot.input.CopyFrom(self._child.plan(session))
-        plan.unpivot.ids.extend([self.col_to_expr(x, session) for x in self.ids])
+        plan.unpivot.ids.extend([id.to_plan(session) for id in self.ids])
         if self.values is not None:
-            plan.unpivot.values.values.extend([self.col_to_expr(x, session) for x in self.values])
+            plan.unpivot.values.values.extend([v.to_plan(session) for v in self.values])
         plan.unpivot.variable_column_name = self.variable_column_name
         plan.unpivot.value_column_name = self.value_column_name
         return plan
@@ -1319,18 +1261,13 @@ class CollectMetrics(LogicalPlan):
         self,
         child: Optional["LogicalPlan"],
         observation: Union[str, "Observation"],
-        exprs: List["ColumnOrName"],
+        exprs: List[Column],
     ) -> None:
         super().__init__(child)
         self._observation = observation
+        assert all(isinstance(e, Column) for e in exprs)
         self._exprs = exprs
 
-    def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") 
-> proto.Expression:
-        if isinstance(col, Column):
-            return col.to_plan(session)
-        else:
-            return self.unresolved_attr(col)
-
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = self._create_proto_relation()
@@ -1340,7 +1277,7 @@ class CollectMetrics(LogicalPlan):
             if isinstance(self._observation, str)
             else str(self._observation._name)
         )
-        plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
+        plan.collect_metrics.metrics.extend([e.to_plan(session) for e in self._exprs])
         return plan
 
     @property
@@ -1570,7 +1507,7 @@ class StatSampleBy(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        col: "ColumnOrName",
+        col: Column,
         fractions: Dict[Any, float],
         seed: Optional[int],
     ) -> None:
@@ -1584,13 +1521,8 @@ class StatSampleBy(LogicalPlan):
 
         assert seed is None or isinstance(seed, int)
 
-        if isinstance(col, Column):
-            self._col = col
-        else:
-            self._col = Column(ColumnReference(col))
-
+        self._col = col
         self._fractions = fractions
-
         self._seed = seed
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -1767,17 +1699,11 @@ class WriteOperationV2(LogicalPlan):
         super(WriteOperationV2, self).__init__(child)
         self.table_name: Optional[str] = table_name
         self.provider: Optional[str] = None
-        self.partitioning_columns: List["ColumnOrName"] = []
+        self.partitioning_columns: List[Column] = []
         self.options: dict[str, Optional[str]] = {}
         self.table_properties: dict[str, Optional[str]] = {}
         self.mode: Optional[str] = None
-        self.overwrite_condition: Optional["ColumnOrName"] = None
-
-    def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") 
-> proto.Expression:
-        if isinstance(col, Column):
-            return col.to_plan(session)
-        else:
-            return self.unresolved_attr(col)
+        self.overwrite_condition: Optional[Column] = None
 
     def command(self, session: "SparkConnectClient") -> proto.Command:
         assert self._child is not None
@@ -1789,7 +1715,7 @@ class WriteOperationV2(LogicalPlan):
             plan.write_operation_v2.provider = self.provider
 
         plan.write_operation_v2.partitioning_columns.extend(
-            [self.col_to_expr(x, session) for x in self.partitioning_columns]
+            [c.to_plan(session) for c in self.partitioning_columns]
         )
 
         for k in self.options:
@@ -1818,7 +1744,7 @@ class WriteOperationV2(LogicalPlan):
                plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_REPLACE
                 if self.overwrite_condition is not None:
                     plan.write_operation_v2.overwrite_condition.CopyFrom(
-                        self.col_to_expr(self.overwrite_condition, session)
+                        self.overwrite_condition.to_plan(session)
                     )
             elif wm == "create_or_replace":
                plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 51698f262fc5..0e9c9128bdbf 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -31,6 +31,7 @@ from pyspark.sql.readwriter import (
     DataFrameWriterV2 as PySparkDataFrameWriterV2,
 )
 from pyspark.errors import PySparkAttributeError, PySparkTypeError, PySparkValueError
+from pyspark.sql.connect.functions import builtin as F
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.dataframe import DataFrame
@@ -876,8 +877,7 @@ class DataFrameWriterV2(OptionUtils):
     tableProperty.__doc__ = PySparkDataFrameWriterV2.tableProperty.__doc__
 
     def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFrameWriterV2":
-        self._write.partitioning_columns = [col]
-        self._write.partitioning_columns.extend(cols)
+        self._write.partitioning_columns = [F._to_col(c) for c in [col] + list(cols)]
         return self
 
     partitionedBy.__doc__ = PySparkDataFrameWriterV2.partitionedBy.__doc__
@@ -916,7 +916,7 @@ class DataFrameWriterV2(OptionUtils):
 
     def overwrite(self, condition: "ColumnOrName") -> None:
         self._write.mode = "overwrite"
-        self._write.overwrite_condition = condition
+        self._write.overwrite_condition = F._to_col(condition)
         self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )


