This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d17a8613a68  [SPARK-45047][PYTHON][CONNECT] `DataFrame.groupBy` support ordinals
d17a8613a68 is described below

commit d17a8613a68af076bc796881831382c29df4d90e
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Sep 4 15:23:08 2023 -0700

    [SPARK-45047][PYTHON][CONNECT] `DataFrame.groupBy` support ordinals

    ### What changes were proposed in this pull request?
    Make `DataFrame.groupBy` accept ordinals.

    ### Why are the changes needed?
    For feature parity: SQL already supports referring to select-list positions by ordinal, e.g.

    ```
    select target_country, ua_date, sum(spending_usd)
    from df
    group by 2, 1
    order by 2, 3 desc
    ```

    This PR focuses on the `groupBy` method.

    ### Does this PR introduce _any_ user-facing change?
    Yes, a new feature:

    ```
    In [2]: from pyspark.sql import functions as sf

    In [3]: df = spark.createDataFrame([(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)], ["a", "b"])

    In [4]: df.select("a", sf.lit(1), "b").groupBy("a", 2).agg(sf.sum("b")).show()
    +---+---+------+
    |  a|  1|sum(b)|
    +---+---+------+
    |  1|  1|     3|
    |  2|  1|     3|
    |  3|  1|     3|
    +---+---+------+
    ```

    ### How was this patch tested?
    Added unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #42767 from zhengruifeng/py_groupby_index.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 python/pyspark/sql/_typing.pyi                     |  1 +
 python/pyspark/sql/connect/_typing.py              |  2 +
 python/pyspark/sql/connect/dataframe.py            |  9 ++-
 python/pyspark/sql/dataframe.py                    | 66 ++++++++++++++++++++--
 python/pyspark/sql/tests/test_group.py             | 61 ++++++++++++++++++++
 python/pyspark/sql/tests/typing/test_dataframe.yml |  2 +-
 6 files changed, 133 insertions(+), 8 deletions(-)
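[Editor's note] Before the diff, a minimal sketch of the behavior being added. This is not part of the commit: it assumes a build of Spark that includes this patch and a live SparkSession bound to `spark`, and shows how a 1-based ordinal in `groupBy` mirrors SQL's `GROUP BY 1`:

```python
# Minimal sketch (not part of this commit): exercises the new 1-based
# ordinal support in DataFrame.groupBy.  Assumes Spark with this patch.
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 10), (1, 20), (2, 30)], ["a", "b"])

# groupBy(1) groups by the first column of the current select list ("a"),
# exactly like SQL's GROUP BY 1.
df.groupBy(1).agg(sf.sum("b")).show()

# Equivalent SQL, for comparison:
df.createOrReplaceTempView("v")
spark.sql("SELECT a, SUM(b) FROM v GROUP BY 1").show()
```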
diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index 3d095f55709..cee44c4aa06 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -36,6 +36,7 @@ from pyspark.sql.column import Column
 
 ColumnOrName = Union[Column, str]
 ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
+ColumnOrNameOrOrdinal = Union[Column, str, int]
 DecimalLiteral = decimal.Decimal
 DateTimeLiteral = Union[datetime.datetime, datetime.date]
 LiteralType = PrimitiveType
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 4c76e37659c..471af24f40d 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -37,6 +37,8 @@ from pyspark.sql.streaming.state import GroupState
 
 ColumnOrName = Union[Column, str]
 
+ColumnOrNameOrOrdinal = Union[Column, str, int]
+
 PrimitiveType = Union[bool, float, int, str]
 
 OptionalPrimitiveType = Optional[PrimitiveType]
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c42de589f8d..86a63536185 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -85,6 +85,7 @@ from pyspark.sql.pandas.types import from_arrow_schema
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import (
         ColumnOrName,
+        ColumnOrNameOrOrdinal,
         LiteralType,
         PrimitiveType,
         OptionalPrimitiveType,
@@ -476,7 +477,7 @@ class DataFrame:
 
     first.__doc__ = PySparkDataFrame.first.__doc__
 
-    def groupBy(self, *cols: "ColumnOrName") -> GroupedData:
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> GroupedData:
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]
 
@@ -486,6 +487,12 @@ class DataFrame:
                 _cols.append(c)
             elif isinstance(c, str):
                 _cols.append(self[c])
+            elif isinstance(c, int) and not isinstance(c, bool):
+                # TODO: should introduce dedicated error class
+                if c < 1:
+                    raise IndexError(f"Column ordinal must be positive but got {c}")
+                # ordinal is 1-based
+                _cols.append(self[c - 1])
             else:
                 raise PySparkTypeError(
                     error_class="NOT_COLUMN_OR_STR",
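[Editor's note] The ordinal branch above is the heart of the change: reject non-positive ordinals, then translate the 1-based position into a 0-based `__getitem__` lookup. Pulled out of the diff as a standalone sketch; the `resolve_ordinals` helper name is invented for illustration, as the committed code inlines this logic in `groupBy`:

```python
# Standalone sketch of the ordinal-resolution rule from the diff above.
# `resolve_ordinals` is a hypothetical name; the real code inlines this.
from typing import List, Union

def resolve_ordinals(df, cols: List[Union[str, int]]) -> list:
    resolved = []
    for c in cols:
        # bool is a subclass of int in Python, so it must be excluded
        # explicitly, or groupBy(True) would be treated as ordinal 1.
        if isinstance(c, int) and not isinstance(c, bool):
            if c < 1:
                raise IndexError(f"Column ordinal must be positive but got {c}")
            # ordinals are 1-based; DataFrame.__getitem__ is 0-based
            resolved.append(df[c - 1])
        else:
            # plain column names are looked up directly
            resolved.append(df[c])
    return resolved
```

The explicit `bool` check matters because `isinstance(True, int)` is `True` in Python; without it, a boolean argument would silently be treated as an ordinal.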
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 64592311a13..4b8bdd1c277 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -67,7 +67,12 @@ from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
 if TYPE_CHECKING:
     from pyspark._typing import PrimitiveType
     from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
-    from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
+    from pyspark.sql._typing import (
+        ColumnOrName,
+        ColumnOrNameOrOrdinal,
+        LiteralType,
+        OptionalPrimitiveType,
+    )
     from pyspark.sql.context import SQLContext
     from pyspark.sql.session import SparkSession
     from pyspark.sql.group import GroupedData
@@ -2919,6 +2924,26 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             cols = cols[0]
         return self._jseq(cols, _to_java_column)
 
+    def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> JavaObject:
+        """Return a JVM Seq of Columns from a list of Column or column names or column ordinals.
+
+        If `cols` has only one list in it, cols[0] will be used as the list.
+        """
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+
+        _cols = []
+        for c in cols:
+            if isinstance(c, int) and not isinstance(c, bool):
+                # TODO: should introduce dedicated error class
+                if c < 1:
+                    raise IndexError(f"Column ordinal must be positive but got {c}")
+                # ordinal is 1-based
+                _cols.append(self[c - 1])
+            else:
+                _cols.append(c)  # type: ignore[arg-type]
+        return self._jseq(_cols, _to_java_column)
+
     def _sort_cols(
         self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
     ) -> JavaObject:
@@ -3588,14 +3613,14 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrame(jdf, self.sparkSession)
 
     @overload
-    def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
         ...
 
     @overload
-    def groupBy(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+    def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
         ...
 
-    def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
         """Groups the :class:`DataFrame` using the specified columns,
         so we can run aggregation on them. See :class:`GroupedData`
         for all the available aggregate functions.
@@ -3607,18 +3632,26 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.0.0
+            Supports column ordinal.
+
         Parameters
         ----------
         cols : list, str or :class:`Column`
             columns to group by.
             Each element should be a column name (string) or an expression (:class:`Column`)
-            or list of them.
+            or a column ordinal (int, 1-based) or list of them.
 
         Returns
         -------
         :class:`GroupedData`
             Grouped data by given columns.
 
+        Notes
+        -----
+        A column ordinal starts from 1, which is different from the
+        0-based :meth:`__getitem__`.
+
         Examples
         --------
         >>> df = spark.createDataFrame([
@@ -3653,6 +3686,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  Bob|       5|
         +-----+--------+
 
+        Also group-by 'name', but using the column ordinal.
+
+        >>> df.groupBy(2).max().sort("name").show()
+        +-----+--------+
+        | name|max(age)|
+        +-----+--------+
+        |Alice|       2|
+        |  Bob|       5|
+        +-----+--------+
+
         Group-by 'name' and 'age', and calculate the number of rows in each group.
 
         >>> df.groupBy(["name", df.age]).count().sort("name", "age").show()
         +-----+---+-----+
         | name|age|count|
         +-----+---+-----+
         |Alice|  2|    1|
         |  Bob|  2|    2|
         |  Bob|  5|    1|
         +-----+---+-----+
+
+        Also group-by 'name' and 'age', but using the column ordinal.
+
+        >>> df.groupBy([df.name, 1]).count().sort("name", "age").show()
+        +-----+---+-----+
+        | name|age|count|
+        +-----+---+-----+
+        |Alice|  2|    1|
+        |  Bob|  2|    2|
+        |  Bob|  5|    1|
+        +-----+---+-----+
         """
-        jgd = self._jdf.groupBy(self._jcols(*cols))
+        jgd = self._jdf.groupBy(self._jcols_ordinal(*cols))
         from pyspark.sql.group import GroupedData
 
         return GroupedData(jgd, self)
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
index 2715571a44d..d481d725ebf 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -16,7 +16,9 @@
 #
 
 from pyspark.sql import Row
+from pyspark.sql import functions as sf
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing import assertDataFrameEqual, assertSchemaEqual
 
 
 class GroupTestsMixin:
@@ -35,6 +37,65 @@ class GroupTestsMixin:
         # test deprecated countDistinct
         self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
 
+    def test_group_by_ordinal(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                (1, 1),
+                (1, 2),
+                (2, 1),
+                (2, 2),
+                (3, 1),
+                (3, 2),
+            ],
+            ["a", "b"],
+        )
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+
+            # basic case
+            df1 = spark.sql("select a, sum(b) from v group by 1;")
+            df2 = df.groupBy(1).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # constant case
+            df1 = spark.sql("select 1, 2, sum(b) from v group by 1, 2;")
+            df2 = df.select(sf.lit(1), sf.lit(2), "b").groupBy(1, 2).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # duplicate group by column
+            df1 = spark.sql("select a, 1, sum(b) from v group by a, 1;")
+            df2 = df.select("a", sf.lit(1), "b").groupBy("a", 2).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            df1 = spark.sql("select a, 1, sum(b) from v group by 1, 2;")
+            df2 = df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # group by a non-aggregate expression's ordinal
+            df1 = spark.sql("select a, b + 2, count(2) from v group by a, 2;")
+            df2 = df.select("a", df.b + 2).groupBy(1, 2).agg(sf.count(sf.lit(2)))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # negative cases: ordinal out of range
+            with self.assertRaises(IndexError):
+                df.groupBy(0).agg(sf.sum("b"))
+
+            with self.assertRaises(IndexError):
+                df.groupBy(-1).agg(sf.sum("b"))
+
+            with self.assertRaises(IndexError):
+                df.groupBy(3).agg(sf.sum("b"))
+
+            with self.assertRaises(IndexError):
+                df.groupBy(10).agg(sf.sum("b"))
+
 
 class GroupTests(GroupTestsMixin, ReusedSQLTestCase):
     pass
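[Editor's note] The tests above pin down a subtlety also called out in the new docstring Notes: ordinals are 1-based while `__getitem__` is 0-based. A quick sketch of that correspondence, outside the test harness and assuming a live SparkSession bound to `spark`:

```python
# Sketch (assumes a live SparkSession `spark` and Spark with this patch):
# groupBy ordinals are 1-based, while DataFrame.__getitem__ is 0-based,
# so ordinal 1 and index 0 address the same column.
df = spark.createDataFrame([(1, 1), (1, 2), (2, 1)], ["a", "b"])

by_ordinal = df.groupBy(1).count()      # ordinal 1 -> column "a"
by_getitem = df.groupBy(df[0]).count()  # index 0   -> column "a"

assert sorted(by_ordinal.collect()) == sorted(by_getitem.collect())

# Non-positive or out-of-range ordinals fail eagerly with IndexError,
# as the negative cases in the test above verify:
# df.groupBy(0), df.groupBy(-1), and df.groupBy(3) all raise.
```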
"age"]) df.groupBy([col("name"), col("age")]) df.groupby([col("name"), col("age")]) - df.groupBy(["name", col("age")]) # E: Argument 1 to "groupBy" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str]]" [arg-type] + df.groupBy(["name", col("age")]) # E: Argument 1 to "groupBy" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str], List[int]]" [arg-type] - case: rollup --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org