This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new aea13fca5d57 [SPARK-47500][PYTHON][CONNECT] Factor column name handling out of `plan.py`
aea13fca5d57 is described below

commit aea13fca5d5794d03a396036f838584cc9c46986
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Mar 22 15:09:31 2024 +0900

    [SPARK-47500][PYTHON][CONNECT] Factor column name handling out of `plan.py`

    ### What changes were proposed in this pull request?
    Factor column name handling out of `plan.py`.

    ### Why are the changes needed?
    There is too much parameter preprocessing in `plan.py` (e.g. column name handling), with multiple duplicated helper functions scattered around, which makes the code hard to follow at times. (A sketch of the consolidated pattern follows the diff below.)

    ### Does this PR introduce _any_ user-facing change?
    No, this is a code refactor only.

    ### How was this patch tested?
    CI.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #45636 from zhengruifeng/plan_clean_up.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py  |  70 ++++++++--------
 python/pyspark/sql/connect/plan.py       | 136 +++++++------------------------
 python/pyspark/sql/connect/readwriter.py |   6 +-
 3 files changed, 71 insertions(+), 141 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 741606c89aa4..2a22d02387ae 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -182,8 +182,10 @@ class DataFrame:
     def select(self, *cols: "ColumnOrName") -> "DataFrame":
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]
-
-        return DataFrame(plan.Project(self._plan, *cols), session=self._session)
+        return DataFrame(
+            plan.Project(self._plan, [F._to_col(c) for c in cols]),
+            session=self._session,
+        )
 
     select.__doc__ = PySparkDataFrame.select.__doc__
 
@@ -197,7 +199,7 @@ class DataFrame:
             else:
                 sql_expr.extend([F.expr(e) for e in element])
 
-        return DataFrame(plan.Project(self._plan, *sql_expr), session=self._session)
+        return DataFrame(plan.Project(self._plan, sql_expr), session=self._session)
 
     selectExpr.__doc__ = PySparkDataFrame.selectExpr.__doc__
 
@@ -309,18 +311,20 @@ class DataFrame:
                 )
             if len(cols) == 0:
                 return DataFrame(
-                    plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=True),
+                    plan.Repartition(self._plan, numPartitions, shuffle=True),
                     self._session,
                 )
             else:
                 return DataFrame(
-                    plan.RepartitionByExpression(self._plan, numPartitions, list(cols)),
+                    plan.RepartitionByExpression(
+                        self._plan, numPartitions, [F._to_col(c) for c in cols]
+                    ),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
             cols = (numPartitions,) + cols
             return DataFrame(
-                plan.RepartitionByExpression(self._plan, None, list(cols)),
+                plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]),
                 self.sparkSession,
             )
         else:
@@ -345,14 +349,14 @@ class DataFrame:
     def repartitionByRange(  # type: ignore[misc]
        self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> "DataFrame":
-        def _convert_col(col: "ColumnOrName") -> "ColumnOrName":
+        def _convert_col(col: "ColumnOrName") -> Column:
             if isinstance(col, Column):
                 if isinstance(col._expr, SortOrder):
                     return col
                 else:
-                    return Column(SortOrder(col._expr))
+                    return col.asc()
             else:
-                return Column(SortOrder(ColumnReference(col)))
+                return F.col(col).asc()
 
         if isinstance(numPartitions, int):
             if not numPartitions > 0:
@@ -369,18 +373,17 @@ class DataFrame:
                     message_parameters={"item": "cols"},
"cols"}, ) else: - sort = [] - sort.extend([_convert_col(c) for c in cols]) return DataFrame( - plan.RepartitionByExpression(self._plan, numPartitions, sort), + plan.RepartitionByExpression( + self._plan, numPartitions, [_convert_col(c) for c in cols] + ), self.sparkSession, ) elif isinstance(numPartitions, (str, Column)): - cols = (numPartitions,) + cols - sort = [] - sort.extend([_convert_col(c) for c in cols]) return DataFrame( - plan.RepartitionByExpression(self._plan, None, sort), + plan.RepartitionByExpression( + self._plan, None, [_convert_col(c) for c in [numPartitions] + list(cols)] + ), self.sparkSession, ) else: @@ -648,12 +651,18 @@ class DataFrame: if tolerance is not None: assert isinstance(tolerance, Column), "tolerance should be Column" + def _convert_col(df: "DataFrame", col: "ColumnOrName") -> Column: + if isinstance(col, Column): + return col + else: + return Column(ColumnReference(col, df._plan._plan_id)) + return DataFrame( plan.AsOfJoin( left=self._plan, right=other._plan, - left_as_of=leftAsOfColumn, - right_as_of=rightAsOfColumn, + left_as_of=_convert_col(self, leftAsOfColumn), + right_as_of=_convert_col(other, rightAsOfColumn), on=on, how=how, tolerance=tolerance, @@ -940,24 +949,21 @@ class DataFrame: ) -> "DataFrame": assert ids is not None, "ids must not be None" - def to_jcols( + def _convert_cols( cols: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]] - ) -> List["ColumnOrName"]: + ) -> List[Column]: if cols is None: - lst = [] - elif isinstance(cols, tuple): - lst = list(cols) - elif isinstance(cols, list): - lst = cols + return [] + elif isinstance(cols, (tuple, list)): + return [F._to_col(c) for c in cols] else: - lst = [cols] - return lst + return [F._to_col(cols)] return DataFrame( plan.Unpivot( self._plan, - to_jcols(ids), - to_jcols(values) if values is not None else None, + _convert_cols(ids), + _convert_cols(values) if values is not None else None, variableColumnName, valueColumnName, ), @@ -1645,9 +1651,7 @@ class DataFrame: def sampleBy( self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None ) -> "DataFrame": - if isinstance(col, str): - col = Column(ColumnReference(col)) - elif not isinstance(col, Column): + if not isinstance(col, (str, Column)): raise PySparkTypeError( error_class="NOT_COLUMN_OR_STR", message_parameters={"arg_name": "col", "arg_type": type(col).__name__}, @@ -1671,7 +1675,7 @@ class DataFrame: fractions[k] = float(v) seed = seed if seed is not None else random.randint(0, sys.maxsize) return DataFrame( - plan.StatSampleBy(child=self._plan, col=col, fractions=fractions, seed=seed), + plan.StatSampleBy(child=self._plan, col=F._to_col(col), fractions=fractions, seed=seed), session=self._session, ) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 7a93580aa112..fc186b3a3df7 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -37,7 +37,6 @@ from pyspark.sql.connect.column import Column from pyspark.sql.connect.expressions import ( Expression, SortOrder, - ColumnReference, LiteralExpression, ) from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType @@ -49,7 +48,6 @@ from pyspark.errors import ( ) if TYPE_CHECKING: - from pyspark.sql.connect._typing import ColumnOrName from pyspark.sql.connect.client import SparkConnectClient from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.observation import Observation @@ -77,22 +75,6 @@ class 
         plan.common.plan_id = self._plan_id
         return plan
 
-    def unresolved_attr(self, colName: str) -> proto.Expression:
-        """Creates an unresolved attribute from a column name."""
-        exp = proto.Expression()
-        exp.unresolved_attribute.unparsed_identifier = colName
-        return exp
-
-    def to_attr_or_expression(
-        self, col: "ColumnOrName", session: "SparkConnectClient"
-    ) -> proto.Expression:
-        """Returns either an instance of an unresolved attribute or the serialized
-        expression value of the column."""
-        if type(col) is str:
-            return self.unresolved_attr(col)
-        else:
-            return cast(Column, col).to_plan(session)
-
     def plan(self, session: "SparkConnectClient") -> proto.Relation:  # type: ignore[empty-body]
         ...
 
@@ -465,35 +447,20 @@ class Project(LogicalPlan):
     """
 
-    def __init__(self, child: Optional["LogicalPlan"], *columns: "ColumnOrName") -> None:
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        columns: List[Column],
+    ) -> None:
         super().__init__(child)
-        self._columns = list(columns)
-        self._verify_expressions()
-
-    def _verify_expressions(self) -> None:
-        """Ensures that all input arguments are instances of Expression or String."""
-        for c in self._columns:
-            if not isinstance(c, (Column, str)):
-                raise PySparkTypeError(
-                    error_class="NOT_LIST_OF_COLUMN_OR_STR",
-                    message_parameters={"arg_name": "columns"},
-                )
+        assert all(isinstance(c, Column) for c in columns)
+        self._columns = columns
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        from pyspark.sql.connect.functions import col
-
         assert self._child is not None
         plan = self._create_proto_relation()
         plan.project.input.CopyFrom(self._child.plan(session))
-
-        proj_exprs = []
-        for c in self._columns:
-            if isinstance(c, Column):
-                proj_exprs.append(c.to_plan(session))
-            else:
-                proj_exprs.append(col(c).to_plan(session))
-
-        plan.project.expressions.extend(proj_exprs)
+        plan.project.expressions.extend([c.to_plan(session) for c in self._columns])
         return plan
 
@@ -896,7 +863,7 @@ class Join(LogicalPlan):
                 if isinstance(self.on, str):
                     plan.join.using_columns.append(self.on)
                 else:
-                    plan.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
+                    plan.join.join_condition.CopyFrom(self.on.to_plan(session))
             elif len(self.on) > 0:
                 if isinstance(self.on[0], str):
                     plan.join.using_columns.extend(cast(str, self.on))
@@ -936,8 +903,8 @@ class AsOfJoin(LogicalPlan):
         self,
         left: LogicalPlan,
         right: LogicalPlan,
-        left_as_of: "ColumnOrName",
-        right_as_of: "ColumnOrName",
+        left_as_of: Column,
+        right_as_of: Column,
         on: Optional[Union[str, List[str], Column, List[Column]]],
         how: str,
         tolerance: Optional[Column],
@@ -960,19 +927,8 @@ class AsOfJoin(LogicalPlan):
         plan.as_of_join.left.CopyFrom(self.left.plan(session))
         plan.as_of_join.right.CopyFrom(self.right.plan(session))
 
-        if isinstance(self.left_as_of, Column):
-            plan.as_of_join.left_as_of.CopyFrom(self.left_as_of.to_plan(session))
-        else:
-            plan.as_of_join.left_as_of.CopyFrom(
-                ColumnReference(self.left_as_of, self.left._plan_id).to_plan(session)
-            )
-
-        if isinstance(self.right_as_of, Column):
-            plan.as_of_join.right_as_of.CopyFrom(self.right_as_of.to_plan(session))
-        else:
-            plan.as_of_join.right_as_of.CopyFrom(
-                ColumnReference(self.right_as_of, self.right._plan_id).to_plan(session)
-            )
+        plan.as_of_join.left_as_of.CopyFrom(self.left_as_of.to_plan(session))
+        plan.as_of_join.right_as_of.CopyFrom(self.right_as_of.to_plan(session))
 
         if self.on is not None:
             if not isinstance(self.on, list):
@@ -1128,26 +1084,18 @@ class RepartitionByExpression(LogicalPlan):
         self,
         child: Optional["LogicalPlan"],
         num_partitions: Optional[int],
-        columns: List["ColumnOrName"],
+        columns: List[Column],
     ) -> None:
         super().__init__(child)
         self.num_partitions = num_partitions
+        assert all(isinstance(c, Column) for c in columns)
         self.columns = columns
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         plan = self._create_proto_relation()
-
-        part_exprs = []
-        for c in self.columns:
-            if isinstance(c, Column):
-                part_exprs.append(c.to_plan(session))
-            elif c == "*":
-                exp = proto.Expression()
-                exp.unresolved_star.SetInParent()
-                part_exprs.append(exp)
-            else:
-                part_exprs.append(self.unresolved_attr(c))
-        plan.repartition_by_expression.partition_exprs.extend(part_exprs)
+        plan.repartition_by_expression.partition_exprs.extend(
+            [c.to_plan(session) for c in self.columns]
+        )
 
         if self._child is not None:
             plan.repartition_by_expression.input.CopyFrom(self._child.plan(session))
@@ -1283,8 +1231,8 @@ class Unpivot(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        ids: List["ColumnOrName"],
-        values: Optional[List["ColumnOrName"]],
+        ids: List[Column],
+        values: Optional[List[Column]],
         variable_column_name: str,
         value_column_name: str,
     ) -> None:
@@ -1294,19 +1242,13 @@ class Unpivot(LogicalPlan):
         self.variable_column_name = variable_column_name
         self.value_column_name = value_column_name
 
-    def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> proto.Expression:
-        if isinstance(col, Column):
-            return col.to_plan(session)
-        else:
-            return self.unresolved_attr(col)
-
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = self._create_proto_relation()
         plan.unpivot.input.CopyFrom(self._child.plan(session))
-        plan.unpivot.ids.extend([self.col_to_expr(x, session) for x in self.ids])
+        plan.unpivot.ids.extend([id.to_plan(session) for id in self.ids])
         if self.values is not None:
-            plan.unpivot.values.values.extend([self.col_to_expr(x, session) for x in self.values])
+            plan.unpivot.values.values.extend([v.to_plan(session) for v in self.values])
         plan.unpivot.variable_column_name = self.variable_column_name
         plan.unpivot.value_column_name = self.value_column_name
         return plan
@@ -1319,18 +1261,13 @@ class CollectMetrics(LogicalPlan):
         self,
         child: Optional["LogicalPlan"],
         observation: Union[str, "Observation"],
-        exprs: List["ColumnOrName"],
+        exprs: List[Column],
     ) -> None:
         super().__init__(child)
         self._observation = observation
+        assert all(isinstance(e, Column) for e in exprs)
         self._exprs = exprs
 
-    def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> proto.Expression:
-        if isinstance(col, Column):
-            return col.to_plan(session)
-        else:
-            return self.unresolved_attr(col)
-
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = self._create_proto_relation()
@@ -1340,7 +1277,7 @@ class CollectMetrics(LogicalPlan):
             if isinstance(self._observation, str)
             else str(self._observation._name)
         )
-        plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
+        plan.collect_metrics.metrics.extend([e.to_plan(session) for e in self._exprs])
         return plan
 
     @property
@@ -1570,7 +1507,7 @@ class StatSampleBy(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        col: "ColumnOrName",
+        col: Column,
         fractions: Dict[Any, float],
         seed: Optional[int],
     ) -> None:
@@ -1584,13 +1521,8 @@ class StatSampleBy(LogicalPlan):
 
         assert seed is None or isinstance(seed, int)
 
-        if isinstance(col, Column):
-            self._col = col
-        else:
-            self._col = Column(ColumnReference(col))
-
+        self._col = col
         self._fractions = fractions
-
         self._seed = seed
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
@@ -1767,17 +1699,11 @@ class WriteOperationV2(LogicalPlan):
         super(WriteOperationV2, self).__init__(child)
         self.table_name: Optional[str] = table_name
         self.provider: Optional[str] = None
-        self.partitioning_columns: List["ColumnOrName"] = []
+        self.partitioning_columns: List[Column] = []
         self.options: dict[str, Optional[str]] = {}
         self.table_properties: dict[str, Optional[str]] = {}
         self.mode: Optional[str] = None
-        self.overwrite_condition: Optional["ColumnOrName"] = None
-
-    def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> proto.Expression:
-        if isinstance(col, Column):
-            return col.to_plan(session)
-        else:
-            return self.unresolved_attr(col)
+        self.overwrite_condition: Optional[Column] = None
 
     def command(self, session: "SparkConnectClient") -> proto.Command:
         assert self._child is not None
@@ -1789,7 +1715,7 @@ class WriteOperationV2(LogicalPlan):
             plan.write_operation_v2.provider = self.provider
 
         plan.write_operation_v2.partitioning_columns.extend(
-            [self.col_to_expr(x, session) for x in self.partitioning_columns]
+            [c.to_plan(session) for c in self.partitioning_columns]
         )
 
         for k in self.options:
@@ -1818,7 +1744,7 @@ class WriteOperationV2(LogicalPlan):
             plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_REPLACE
             if self.overwrite_condition is not None:
                 plan.write_operation_v2.overwrite_condition.CopyFrom(
-                    self.col_to_expr(self.overwrite_condition, session)
+                    self.overwrite_condition.to_plan(session)
                 )
         elif wm == "create_or_replace":
             plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 51698f262fc5..0e9c9128bdbf 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -31,6 +31,7 @@ from pyspark.sql.readwriter import (
     DataFrameWriterV2 as PySparkDataFrameWriterV2,
 )
 from pyspark.errors import PySparkAttributeError, PySparkTypeError, PySparkValueError
+from pyspark.sql.connect.functions import builtin as F
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.dataframe import DataFrame
@@ -876,8 +877,7 @@ class DataFrameWriterV2(OptionUtils):
     tableProperty.__doc__ = PySparkDataFrameWriterV2.tableProperty.__doc__
 
     def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFrameWriterV2":
-        self._write.partitioning_columns = [col]
-        self._write.partitioning_columns.extend(cols)
+        self._write.partitioning_columns = [F._to_col(c) for c in [col] + list(cols)]
         return self
 
     partitionedBy.__doc__ = PySparkDataFrameWriterV2.partitionedBy.__doc__
@@ -916,7 +916,7 @@ class DataFrameWriterV2(OptionUtils):
 
     def overwrite(self, condition: "ColumnOrName") -> None:
         self._write.mode = "overwrite"
-        self._write.overwrite_condition = condition
+        self._write.overwrite_condition = F._to_col(condition)
         self._spark.client.execute_command(
             self._write.command(self._spark.client), self._write.observations
         )
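For readers skimming the diff, the consolidated pattern is easy to illustrate outside Spark: normalize `ColumnOrName` arguments to `Column` once at the DataFrame API boundary (the diff uses `F._to_col` for this), so plan nodes only ever receive `Column` and can simply assert their inputs. Below is a minimal self-contained sketch of that shape; the `Column`, `Project`, and `select` stand-ins here are hypothetical simplifications for illustration, not pyspark's actual classes.

```python
from typing import List, Union


class Column:
    """Hypothetical stand-in for pyspark.sql.connect.column.Column."""

    def __init__(self, name: str) -> None:
        self.name = name


ColumnOrName = Union[Column, str]


def _to_col(col: ColumnOrName) -> Column:
    # Single normalization point: strings become Columns here, so no
    # downstream code has to branch on str-vs-Column again.
    return col if isinstance(col, Column) else Column(col)


class Project:
    """Hypothetical stand-in for a plan node that accepts only List[Column]."""

    def __init__(self, columns: List[Column]) -> None:
        # After the refactor, plan nodes assert their inputs instead of
        # re-implementing per-node conversion helpers.
        assert all(isinstance(c, Column) for c in columns)
        self.columns = columns


def select(*cols: ColumnOrName) -> Project:
    # The public API still accepts strings or Columns; conversion
    # happens once, before the plan node is constructed.
    return Project([_to_col(c) for c in cols])


if __name__ == "__main__":
    p = select("id", Column("value"))
    print([c.name for c in p.columns])  # ['id', 'value']
```

This is why the PR can delete `unresolved_attr`, `to_attr_or_expression`, and the repeated `col_to_expr` copies: once every plan-node constructor takes `List[Column]`, each `plan()` method reduces to calling `c.to_plan(session)` on already-normalized inputs.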