This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 0eb96ae6eb86 [SPARK-47620][PYTHON][CONNECT] Add a helper function to sort columns
0eb96ae6eb86 is described below

commit 0eb96ae6eb8680155d4c6974dadaeebd7475a1fc
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Mar 28 12:22:27 2024 +0800

    [SPARK-47620][PYTHON][CONNECT] Add a helper function to sort columns

    ### What changes were proposed in this pull request?
    Add a helper function `_sort_col` to sort columns

    ### Why are the changes needed?
    simple code refactoring

    ### Does this PR introduce _any_ user-facing change?
    no

    ### How was this patch tested?
    ci

    ### Was this patch authored or co-authored using generative AI tooling?
    no

    Closes #45743 from zhengruifeng/connect_sort_col.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py         | 16 +++-------------
 python/pyspark/sql/connect/functions/builtin.py | 12 ++++++++++++
 python/pyspark/sql/connect/plan.py              | 15 ++-------------
 3 files changed, 17 insertions(+), 26 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 82929ccfbc4f..672ac8b9c25c 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -72,7 +72,6 @@ from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
-    SortOrder,
     ColumnReference,
     UnresolvedRegex,
     UnresolvedStar,
@@ -349,15 +348,6 @@ class DataFrame:
     def repartitionByRange(  # type: ignore[misc]
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> "DataFrame":
-        def _convert_col(col: "ColumnOrName") -> Column:
-            if isinstance(col, Column):
-                if isinstance(col._expr, SortOrder):
-                    return col
-                else:
-                    return col.asc()
-            else:
-                return F.col(col).asc()
-
         if isinstance(numPartitions, int):
             if not numPartitions > 0:
                 raise PySparkValueError(
@@ -375,14 +365,14 @@ class DataFrame:
             else:
                 return DataFrame(
                     plan.RepartitionByExpression(
-                        self._plan, numPartitions, [_convert_col(c) for c in cols]
+                        self._plan, numPartitions, [F._sort_col(c) for c in cols]
                     ),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
             return DataFrame(
                 plan.RepartitionByExpression(
-                    self._plan, None, [_convert_col(c) for c in [numPartitions] + list(cols)]
+                    self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)]
                 ),
                 self.sparkSession,
             )
@@ -729,7 +719,7 @@ class DataFrame:
                 message_parameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
             )
 
-        return _cols
+        return [F._sort_col(c) for c in _cols]
 
     def sort(
         self,
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index c423c5f188ef..57c1f881eebf 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -45,6 +45,7 @@ from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
     CaseWhen,
+    SortOrder,
     Expression,
     LiteralExpression,
     ColumnReference,
@@ -88,6 +89,17 @@ def _to_col(col: "ColumnOrName") -> Column:
     return col if isinstance(col, Column) else column(col)
 
 
+def _sort_col(col: "ColumnOrName") -> Column:
+    assert isinstance(col, (Column, str))
+    if isinstance(col, Column):
+        if isinstance(col._expr, SortOrder):
+            return col
+        else:
+            return col.asc()
+    else:
+        return column(col).asc()
+
+
 def _invoke_function(name: str, *args: Union[Column, Expression]) -> Column:
     """
     Simple wrapper function that converts the arguments into the appropriate types.
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index d9dd8874398c..863c27fabf6b 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -46,10 +46,7 @@ from pyspark.sql.types import DataType
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.conversion import storage_level_to_proto
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import (
-    Expression,
-    SortOrder,
-)
+from pyspark.sql.connect.expressions import Expression
 from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
 from pyspark.errors import (
     PySparkValueError,
@@ -674,19 +671,11 @@ class Sort(LogicalPlan):
         self.columns = columns
         self.is_global = is_global
 
-    def _convert_col(
-        self, col: Column, session: "SparkConnectClient"
-    ) -> proto.Expression.SortOrder:
-        if isinstance(col._expr, SortOrder):
-            return col._expr.to_plan(session).sort_order
-        else:
-            return SortOrder(col._expr).to_plan(session).sort_order
-
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = self._create_proto_relation()
         plan.sort.input.CopyFrom(self._child.plan(session))
-        plan.sort.order.extend([self._convert_col(c, session) for c in self.columns])
+        plan.sort.order.extend([c.to_plan(session).sort_order for c in self.columns])
         plan.sort.is_global = self.is_global
         return plan
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org