This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d17a8613a68 [SPARK-45047][PYTHON][CONNECT] `DataFrame.groupBy` support ordinals
d17a8613a68 is described below

commit d17a8613a68af076bc796881831382c29df4d90e
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Sep 4 15:23:08 2023 -0700

    [SPARK-45047][PYTHON][CONNECT] `DataFrame.groupBy` support ordinals
    
    ### What changes were proposed in this pull request?
    
    Make `DataFrame.groupBy` accept 1-based column ordinals, in addition to column names and `Column` objects.
    
    ### Why are the changes needed?
    
    For feature parity with SQL, which allows grouping (and ordering) by column ordinals, for example:
    
    ```
    select target_country, ua_date, sum(spending_usd)
    from df
    group by 2, 1
    order by 2, 3 desc
    ```
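
    For reference, a rough sketch of the DataFrame-API equivalent after this change (the column names here are assumed from the SQL snippet above; ordinal support for `orderBy` is out of scope for this PR, so column names are used there):

    ```
    from pyspark.sql import functions as sf

    # Sketch only: assumes a DataFrame `df` with the columns used in the SQL above.
    (
        df.select("target_country", "ua_date", "spending_usd")
        .groupBy(2, 1)  # 1-based ordinals: ua_date, target_country
        .agg(sf.sum("spending_usd"))
        .orderBy("ua_date", sf.desc("sum(spending_usd)"))
        .show()
    )
    ```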
    
    This PR focuses on the `groupBy` method.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, a new feature:
    
    ```
    In [2]: from pyspark.sql import functions as sf
    
    In [3]: df = spark.createDataFrame([(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)], ["a", "b"])
    
    In [4]: df.select("a", sf.lit(1), "b").groupBy("a", 2).agg(sf.sum("b")).show()
    +---+---+------+
    |  a|  1|sum(b)|
    +---+---+------+
    |  1|  1|     3|
    |  2|  1|     3|
    |  3|  1|     3|
    +---+---+------+
    ```
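
    Ordinals are 1-based, so the `2` above refers to the second selected column, `sf.lit(1)`. As a rough sketch of the error behavior (same `df` as above; see `test_group_by_ordinal` in the diff below), non-positive or out-of-range ordinals raise `IndexError`:

    ```
    df.groupBy(0).agg(sf.sum("b"))    # IndexError: ordinal must be positive
    df.groupBy(10).agg(sf.sum("b"))   # IndexError: df has only 2 columns
    ```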
    
    ### How was this patch tested?
    Added unit tests (`test_group_by_ordinal` in `python/pyspark/sql/tests/test_group.py`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #42767 from zhengruifeng/py_groupby_index.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 python/pyspark/sql/_typing.pyi                     |  1 +
 python/pyspark/sql/connect/_typing.py              |  2 +
 python/pyspark/sql/connect/dataframe.py            |  9 ++-
 python/pyspark/sql/dataframe.py                    | 66 ++++++++++++++++++++--
 python/pyspark/sql/tests/test_group.py             | 61 ++++++++++++++++++++
 python/pyspark/sql/tests/typing/test_dataframe.yml |  2 +-
 6 files changed, 133 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index 3d095f55709..cee44c4aa06 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -36,6 +36,7 @@ from pyspark.sql.column import Column
 
 ColumnOrName = Union[Column, str]
 ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
+ColumnOrNameOrOrdinal = Union[Column, str, int]
 DecimalLiteral = decimal.Decimal
 DateTimeLiteral = Union[datetime.datetime, datetime.date]
 LiteralType = PrimitiveType
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 4c76e37659c..471af24f40d 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -37,6 +37,8 @@ from pyspark.sql.streaming.state import GroupState
 
 ColumnOrName = Union[Column, str]
 
+ColumnOrNameOrOrdinal = Union[Column, str, int]
+
 PrimitiveType = Union[bool, float, int, str]
 
 OptionalPrimitiveType = Optional[PrimitiveType]
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c42de589f8d..86a63536185 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -85,6 +85,7 @@ from pyspark.sql.pandas.types import from_arrow_schema
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import (
         ColumnOrName,
+        ColumnOrNameOrOrdinal,
         LiteralType,
         PrimitiveType,
         OptionalPrimitiveType,
@@ -476,7 +477,7 @@ class DataFrame:
 
     first.__doc__ = PySparkDataFrame.first.__doc__
 
-    def groupBy(self, *cols: "ColumnOrName") -> GroupedData:
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> GroupedData:
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]
 
@@ -486,6 +487,12 @@ class DataFrame:
                 _cols.append(c)
             elif isinstance(c, str):
                 _cols.append(self[c])
+            elif isinstance(c, int) and not isinstance(c, bool):
+                # TODO: should introduce dedicated error class
+                if c < 1:
+                    raise IndexError(f"Column ordinal must be positive but got {c}")
+                # ordinal is 1-based
+                _cols.append(self[c - 1])
             else:
                 raise PySparkTypeError(
                     error_class="NOT_COLUMN_OR_STR",
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 64592311a13..4b8bdd1c277 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -67,7 +67,12 @@ from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
 if TYPE_CHECKING:
     from pyspark._typing import PrimitiveType
     from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
-    from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
+    from pyspark.sql._typing import (
+        ColumnOrName,
+        ColumnOrNameOrOrdinal,
+        LiteralType,
+        OptionalPrimitiveType,
+    )
     from pyspark.sql.context import SQLContext
     from pyspark.sql.session import SparkSession
     from pyspark.sql.group import GroupedData
@@ -2919,6 +2924,26 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             cols = cols[0]
         return self._jseq(cols, _to_java_column)
 
+    def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> JavaObject:
+        """Return a JVM Seq of Columns from a list of Column or column names 
or column ordinals.
+
+        If `cols` has only one list in it, cols[0] will be used as the list.
+        """
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+
+        _cols = []
+        for c in cols:
+            if isinstance(c, int) and not isinstance(c, bool):
+                # TODO: should introduce dedicated error class
+                if c < 1:
+                    raise IndexError(f"Column ordinal must be positive but got {c}")
+                # ordinal is 1-based
+                _cols.append(self[c - 1])
+            else:
+                _cols.append(c)  # type: ignore[arg-type]
+        return self._jseq(_cols, _to_java_column)
+
     def _sort_cols(
         self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
     ) -> JavaObject:
@@ -3588,14 +3613,14 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrame(jdf, self.sparkSession)
 
     @overload
-    def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
         ...
 
     @overload
-    def groupBy(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+    def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
         ...
 
-    def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
         """Groups the :class:`DataFrame` using the specified columns,
         so we can run aggregation on them. See :class:`GroupedData`
         for all the available aggregate functions.
@@ -3607,18 +3632,26 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.0.0
+            Supports column ordinal.
+
         Parameters
         ----------
         cols : list, str or :class:`Column`
             columns to group by.
             Each element should be a column name (string) or an expression (:class:`Column`)
-            or list of them.
+            or a column ordinal (int, 1-based) or list of them.
 
         Returns
         -------
         :class:`GroupedData`
             Grouped data by given columns.
 
+        Notes
+        -----
+        A column ordinal starts from 1, which is different from the
+        0-based :meth:`__getitem__`.
+
         Examples
         --------
         >>> df = spark.createDataFrame([
@@ -3653,6 +3686,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  Bob|       5|
         +-----+--------+
 
+        Also group-by 'name', but using the column ordinal.
+
+        >>> df.groupBy(2).max().sort("name").show()
+        +-----+--------+
+        | name|max(age)|
+        +-----+--------+
+        |Alice|       2|
+        |  Bob|       5|
+        +-----+--------+
+
         Group-by 'name' and 'age', and calculate the number of rows in each group.
 
         >>> df.groupBy(["name", df.age]).count().sort("name", "age").show()
@@ -3663,8 +3706,19 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         |  Bob|  2|    2|
         |  Bob|  5|    1|
         +-----+---+-----+
+
+        Also group-by 'name' and 'age', but using the column ordinal.
+
+        >>> df.groupBy([df.name, 1]).count().sort("name", "age").show()
+        +-----+---+-----+
+        | name|age|count|
+        +-----+---+-----+
+        |Alice|  2|    1|
+        |  Bob|  2|    2|
+        |  Bob|  5|    1|
+        +-----+---+-----+
         """
-        jgd = self._jdf.groupBy(self._jcols(*cols))
+        jgd = self._jdf.groupBy(self._jcols_ordinal(*cols))
         from pyspark.sql.group import GroupedData
 
         return GroupedData(jgd, self)
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
index 2715571a44d..d481d725ebf 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -16,7 +16,9 @@
 #
 
 from pyspark.sql import Row
+from pyspark.sql import functions as sf
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing import assertDataFrameEqual, assertSchemaEqual
 
 
 class GroupTestsMixin:
@@ -35,6 +37,65 @@ class GroupTestsMixin:
         # test deprecated countDistinct
         self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
 
+    def test_group_by_ordinal(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                (1, 1),
+                (1, 2),
+                (2, 1),
+                (2, 2),
+                (3, 1),
+                (3, 2),
+            ],
+            ["a", "b"],
+        )
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+
+            # basic case
+            df1 = spark.sql("select a, sum(b) from v group by 1;")
+            df2 = df.groupBy(1).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # constant case
+            df1 = spark.sql("select 1, 2, sum(b) from v group by 1, 2;")
+            df2 = df.select(sf.lit(1), sf.lit(2), "b").groupBy(1, 2).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # duplicate group by column
+            df1 = spark.sql("select a, 1, sum(b) from v group by a, 1;")
+            df2 = df.select("a", sf.lit(1), "b").groupBy("a", 
2).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            df1 = spark.sql("select a, 1, sum(b) from v group by 1, 2;")
+            df2 = df.select("a", sf.lit(1), "b").groupBy(1, 2).agg(sf.sum("b"))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # group by a non-aggregate expression's ordinal
+            df1 = spark.sql("select a, b + 2, count(2) from v group by a, 2;")
+            df2 = df.select("a", df.b + 2).groupBy(1, 
2).agg(sf.count(sf.lit(2)))
+            assertSchemaEqual(df1.schema, df2.schema)
+            assertDataFrameEqual(df1, df2)
+
+            # negative cases: ordinal out of range
+            with self.assertRaises(IndexError):
+                df.groupBy(0).agg(sf.sum("b"))
+
+            with self.assertRaises(IndexError):
+                df.groupBy(-1).agg(sf.sum("b"))
+
+            with self.assertRaises(IndexError):
+                df.groupBy(3).agg(sf.sum("b"))
+
+            with self.assertRaises(IndexError):
+                df.groupBy(10).agg(sf.sum("b"))
+
 
 class GroupTests(GroupTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/typing/test_dataframe.yml b/python/pyspark/sql/tests/typing/test_dataframe.yml
index d32a09cea82..7aa2f15cfa2 100644
--- a/python/pyspark/sql/tests/typing/test_dataframe.yml
+++ b/python/pyspark/sql/tests/typing/test_dataframe.yml
@@ -71,7 +71,7 @@
     df.groupby(["name", "age"])
     df.groupBy([col("name"), col("age")])
     df.groupby([col("name"), col("age")])
-    df.groupBy(["name", col("age")])  # E: Argument 1 to "groupBy" of 
"DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], 
List[str]]"  [arg-type]
+    df.groupBy(["name", col("age")])  # E: Argument 1 to "groupBy" of 
"DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], 
List[str], List[int]]"  [arg-type]
 
 
 - case: rollup


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
