This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6c885a7cf57d  [SPARK-45074][PYTHON][CONNECT] `DataFrame.{sort, sortWithinPartitions}` support column ordinals
6c885a7cf57d is described below

commit 6c885a7cf57df328b03308cff2eed814bda156e4
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Sep 4 23:31:23 2023 -0700

    [SPARK-45074][PYTHON][CONNECT] `DataFrame.{sort, sortWithinPartitions}` support column ordinals

    ### What changes were proposed in this pull request?
    Make `DataFrame.{sort, sortWithinPartitions}` support column ordinals.

    ### Why are the changes needed?
    For feature parity with SQL, so that the two queries below are equivalent.

    SQL:
    ```
    select a, 1, sum(b) from v group by 1, 2 order by 3, 1;
    ```

    DataFrame:
    ```
    df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b")).sort(3, 1)
    ```

    ### Does this PR introduce _any_ user-facing change?
    Yes, this adds a new feature.

    ### How was this patch tested?
    Added tests.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #42809 from zhengruifeng/py_oderby_ordinal.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
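Read alongside the diff below, a minimal end-to-end sketch of the new behavior may help. It assumes a Spark build that already includes this commit; the local session bootstrap is illustrative and not part of the patch:

```
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)], ["a", "b"]
)

# Ordinals are 1-based positions in the select list, as in SQL's
# "order by 3, 1": sort by sum(b) (3rd column), then by a (1st).
df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b")).sort(3, 1).show()

# A negative ordinal flips the direction: sort(-1) sorts the first
# column descending, equivalent to df.sort(df["a"].desc()).
df.sort(-1).show()
```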
---
 python/pyspark/sql/connect/dataframe.py |  33 ++++++--
 python/pyspark/sql/dataframe.py         | 134 +++++++++++++++++++++++++++++---
 python/pyspark/sql/tests/test_group.py  |  53 +++++++++++++
 3 files changed, 202 insertions(+), 18 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b22fdc1383cf..c443023ce02a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -593,7 +593,9 @@ class DataFrame:
     tail.__doc__ = PySparkDataFrame.tail.__doc__
 
     def _sort_cols(
-        self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
+        self,
+        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+        kwargs: Dict[str, Any],
     ) -> List[Column]:
         """Return a JVM Seq of Columns that describes the sort order"""
         if cols is None:
@@ -602,11 +604,24 @@ class DataFrame:
                 message_parameters={"item": "cols"},
             )
 
-        _cols: List[Column] = []
         if len(cols) == 1 and isinstance(cols[0], list):
-            _cols = [_to_col(c) for c in cols[0]]
-        else:
-            _cols = [_to_col(cast("ColumnOrName", c)) for c in cols]
+            cols = cols[0]
+
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, int) and not isinstance(c, bool):
+                # TODO: should introduce dedicated error class
+                # ordinal is 1-based
+                if c > 0:
+                    _c = self[c - 1]
+                # negative ordinal means sort by desc
+                elif c < 0:
+                    _c = self[-c - 1].desc()
+                else:
+                    raise IndexError("Column ordinal must not be zero!")
+            else:
+                _c = c  # type: ignore[assignment]
+            _cols.append(_to_col(cast("ColumnOrName", _c)))
 
         ascending = kwargs.get("ascending", True)
         if isinstance(ascending, (bool, int)):
@@ -623,7 +638,9 @@ class DataFrame:
         return _cols
 
     def sort(
-        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+        self,
+        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        **kwargs: Any,
     ) -> "DataFrame":
         return DataFrame.withPlan(
             plan.Sort(
@@ -639,7 +656,9 @@ class DataFrame:
     orderBy = sort
 
     def sortWithinPartitions(
-        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+        self,
+        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        **kwargs: Any,
     ) -> "DataFrame":
         return DataFrame.withPlan(
             plan.Sort(
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3d7bdd7a0b2b..f59ae40542b9 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2853,7 +2853,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrame(jdf, self.sparkSession)
 
     def sortWithinPartitions(
-        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+        self,
+        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        **kwargs: Any,
     ) -> "DataFrame":
         """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s).
 
@@ -2862,10 +2864,13 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.0.0
+            Supports column ordinal.
+
         Parameters
         ----------
-        cols : str, list or :class:`Column`, optional
-            list of :class:`Column` or column names to sort by.
+        cols : int, str, list or :class:`Column`, optional
+            list of :class:`Column` or column names or column ordinals to sort by.
 
         Other Parameters
         ----------------
@@ -2879,17 +2884,42 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         :class:`DataFrame`
             DataFrame sorted by partitions.
 
+        Notes
+        -----
+        A column ordinal starts from 1, which is different from the
+        0-based :meth:`__getitem__`.
+        If a column ordinal is negative, it means sort descending.
+
         Examples
         --------
+        >>> from pyspark.sql import functions as sf
        >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
         >>> df.sortWithinPartitions("age", ascending=False)
         DataFrame[age: bigint, name: string]
+
+        >>> df.coalesce(1).sortWithinPartitions(1).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  2|Alice|
+        |  5|  Bob|
+        +---+-----+
+
+        >>> df.coalesce(1).sortWithinPartitions(-1).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|Alice|
+        +---+-----+
         """
         jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
         return DataFrame(jdf, self.sparkSession)
 
     def sort(
-        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+        self,
+        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
+        **kwargs: Any,
     ) -> "DataFrame":
         """Returns a new :class:`DataFrame` sorted by the specified column(s).
 
@@ -2898,10 +2928,13 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.0.0
+            Supports column ordinal.
+
         Parameters
         ----------
-        cols : str, list, or :class:`Column`, optional
-            list of :class:`Column` or column names to sort by.
+        cols : int, str, list, or :class:`Column`, optional
+            list of :class:`Column` or column names or column ordinals to sort by.
 
         Other Parameters
         ----------------
@@ -2915,15 +2948,29 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         :class:`DataFrame`
             Sorted DataFrame.
 
+        Notes
+        -----
+        A column ordinal starts from 1, which is different from the
+        0-based :meth:`__getitem__`.
+        If a column ordinal is negative, it means sort descending.
+
         Examples
         --------
-        >>> from pyspark.sql.functions import desc, asc
+        >>> from pyspark.sql import functions as sf
         >>> df = spark.createDataFrame([
         ...     (2, "Alice"), (5, "Bob")], schema=["age", "name"])
 
         Sort the DataFrame in ascending order.
 
-        >>> df.sort(asc("age")).show()
+        >>> df.sort(sf.asc("age")).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  2|Alice|
+        |  5|  Bob|
+        +---+-----+
+
+        >>> df.sort(1).show()
         +---+-----+
         |age| name|
         +---+-----+
@@ -2940,6 +2987,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  5|  Bob|
         |  2|Alice|
         +---+-----+
+
         >>> df.orderBy(df.age.desc()).show()
         +---+-----+
         |age| name|
@@ -2947,6 +2995,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  5|  Bob|
         |  2|Alice|
         +---+-----+
+
         >>> df.sort("age", ascending=False).show()
         +---+-----+
         |age| name|
@@ -2955,11 +3004,38 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  2|Alice|
         +---+-----+
 
+        >>> df.sort(-1).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|Alice|
+        +---+-----+
+
         Specify multiple columns
 
+        >>> from pyspark.sql import functions as sf
         >>> df = spark.createDataFrame([
         ...     (2, "Alice"), (2, "Bob"), (5, "Bob")], schema=["age", "name"])
-        >>> df.orderBy(desc("age"), "name").show()
+        >>> df.orderBy(sf.desc("age"), "name").show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|Alice|
+        |  2|  Bob|
+        +---+-----+
+
+        >>> df.orderBy(-1, "name").show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|Alice|
+        |  2|  Bob|
+        +---+-----+
+
+        >>> df.orderBy(-1, 2).show()
         +---+-----+
         |age| name|
         +---+-----+
@@ -2978,6 +3054,24 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  2|  Bob|
         |  2|Alice|
         +---+-----+
+
+        >>> df.orderBy([1, "name"], ascending=[False, False]).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|  Bob|
+        |  2|Alice|
+        +---+-----+
+
+        >>> df.orderBy([1, 2], ascending=[False, False]).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  2|  Bob|
+        |  2|Alice|
+        +---+-----+
         """
         jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
         return DataFrame(jdf, self.sparkSession)
@@ -3026,7 +3120,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return self._jseq(_cols, _to_java_column)
 
     def _sort_cols(
-        self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
+        self,
+        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+        kwargs: Dict[str, Any],
     ) -> JavaObject:
         """Return a JVM Seq of Columns that describes the sort order"""
         if not cols:
@@ -3036,7 +3132,23 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             )
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]
-        jcols = [_to_java_column(cast("ColumnOrName", c)) for c in cols]
+
+        jcols = []
+        for c in cols:
+            if isinstance(c, int) and not isinstance(c, bool):
+                # TODO: should introduce dedicated error class
+                # ordinal is 1-based
+                if c > 0:
+                    _c = self[c - 1]
+                # negative ordinal means sort by desc
+                elif c < 0:
+                    _c = self[-c - 1].desc()
+                else:
+                    raise IndexError("Column ordinal must not be zero!")
+            else:
+                _c = c  # type: ignore[assignment]
+            jcols.append(_to_java_column(cast("ColumnOrName", _c)))
+
         ascending = kwargs.get("ascending", True)
         if isinstance(ascending, (bool, int)):
             if not ascending:
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
index d481d725ebfb..6981601cb129 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -96,6 +96,59 @@ class GroupTestsMixin:
         with self.assertRaises(IndexError):
             df.groupBy(10).agg(sf.sum("b"))
 
+    def test_order_by_ordinal(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                (1, 1),
+                (1, 2),
+                (2, 1),
+                (2, 2),
+                (3, 1),
+                (3, 2),
+            ],
+            ["a", "b"],
+        )
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+
+            df1 = spark.sql("select * from v order by 1 desc;")
+            df2 = df.orderBy(-1)
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            df1 = spark.sql("select * from v order by 1 desc, b desc;")
+            df2 = df.orderBy(-1, df.b.desc())
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            df1 = spark.sql("select * from v order by 1 desc, 2 desc;")
+            df2 = df.orderBy(-1, -2)
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # groupby ordinal with orderby ordinal
+            df1 = spark.sql("select a, 1, sum(b) from v group by 1, 2 order by 1;")
+            df2 = df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b")).sort(1)
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            df1 = spark.sql("select a, 1, sum(b) from v group by 1, 2 order by 3, 1;")
+            df2 = df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b")).sort(3, 1)
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+        # negative cases: ordinal out of range
+        with self.assertRaises(IndexError):
+            df.sort(0)
+
+        with self.assertRaises(IndexError):
+            df.orderBy(3)
+
+        with self.assertRaises(IndexError):
+            df.orderBy(-3)
+
 
 class GroupTests(GroupTestsMixin, ReusedSQLTestCase):
     pass
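The ordinal rule that both `_sort_cols` implementations above share can be summarized outside Spark entirely. The helper below is an editor's paraphrase for illustration only (`resolve_ordinal` is a hypothetical name, not part of the patch); plain list indexing reproduces the same `IndexError` behavior that the new negative-path tests assert:

```
from typing import List, Tuple

def resolve_ordinal(columns: List[str], ordinal: int) -> Tuple[str, bool]:
    """Map a 1-based sort ordinal to (column name, descending?)."""
    if ordinal > 0:
        # 1-based; columns[ordinal - 1] raises IndexError when out of range
        return columns[ordinal - 1], False
    elif ordinal < 0:
        # a negative ordinal selects the same column but sorts descending
        return columns[-ordinal - 1], True
    else:
        raise IndexError("Column ordinal must not be zero!")

assert resolve_ordinal(["a", "b"], 1) == ("a", False)
assert resolve_ordinal(["a", "b"], -2) == ("b", True)
```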
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org