This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d14410c6777 [SPARK-46048][PYTHON][CONNECT] Support DataFrame.groupingSets in Python Spark Connect
d14410c6777 is described below

commit d14410c6777e7de7f61e1957fab749da2793f4b8
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Nov 23 16:38:52 2023 +0900

    [SPARK-46048][PYTHON][CONNECT] Support DataFrame.groupingSets in Python Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR adds `DataFrame.groupingSets` in Python Spark Connect.
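    
    For example (a minimal usage sketch, not part of this patch, assuming a Spark Connect `SparkSession` named `spark`):
    
    ```python
    from pyspark.sql import functions as sf
    
    df = spark.createDataFrame([(1, "a"), (1, "b"), (2, "a")], ["id", "name"])
    
    # Aggregate over the grouping sets (id, name) and (id,); the trailing
    # arguments are the user-given group-by columns, as in non-Connect PySpark.
    df.groupingSets(
        [["id", "name"], ["id"]], "id", "name"
    ).agg(sf.count(sf.lit(1)).alias("cnt")).show()
    ```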
    
    ### Why are the changes needed?
    
    For feature parity with the non-Spark-Connect PySpark API.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it adds the new API `DataFrame.groupingSets` in Python Spark Connect.
    
    ### How was this patch tested?
    
    Unit tests were added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43967 from HyukjinKwon/SPARK-46048.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |   9 +
 .../org/apache/spark/sql/connect/dsl/package.scala |  21 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  11 ++
 .../connect/planner/SparkConnectProtoSuite.scala   |  12 ++
 python/pyspark/sql/connect/dataframe.py            |  39 +++++
 python/pyspark/sql/connect/group.py                |  16 +-
 python/pyspark/sql/connect/plan.py                 |  23 ++-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 194 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  36 ++++
 python/pyspark/sql/dataframe.py                    |   1 -
 10 files changed, 262 insertions(+), 100 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index deb33978386..43f692671df 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -327,12 +327,16 @@ message Aggregate {
   // (Optional) Pivots a column of the current `DataFrame` and performs the specified aggregation.
   Pivot pivot = 5;
 
+  // (Optional) List of values that will be translated to columns in the output DataFrame.
+  repeated GroupingSets grouping_sets = 6;
+
   enum GroupType {
     GROUP_TYPE_UNSPECIFIED = 0;
     GROUP_TYPE_GROUPBY = 1;
     GROUP_TYPE_ROLLUP = 2;
     GROUP_TYPE_CUBE = 3;
     GROUP_TYPE_PIVOT = 4;
+    GROUP_TYPE_GROUPING_SETS = 5;
   }
 
   message Pivot {
@@ -345,6 +349,11 @@ message Aggregate {
     // the distinct values of the column.
     repeated Expression.Literal values = 2;
   }
+
+  message GroupingSets {
+    // (Required) Individual grouping set
+    repeated Expression grouping_set = 1;
+  }
 }
 
 // Relation of type [[Sort]].
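
For reference, the new `grouping_sets` field and the `GROUP_TYPE_GROUPING_SETS` value are populated by the client roughly as below (an illustrative sketch with the generated proto classes; the `attr` helper is hypothetical, and the real logic lives in the `plan.py` change further down):

```python
from pyspark.sql.connect import proto


def attr(name: str) -> proto.Expression:
    # Hypothetical helper: wrap a column name in an unresolved attribute.
    e = proto.Expression()
    e.unresolved_attribute.unparsed_identifier = name
    return e


agg = proto.Aggregate()
agg.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS
agg.grouping_expressions.append(attr("id"))
# One GroupingSets entry per grouping set: here (id, name) and (id,).
for grouping_set in [["id", "name"], ["id"]]:
    agg.grouping_sets.append(
        proto.Aggregate.GroupingSets(grouping_set=[attr(c) for c in grouping_set])
    )
```
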
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 5fd1a035385..18c71ae4ace 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -800,6 +800,27 @@ package object dsl {
         Relation.newBuilder().setAggregate(agg.build()).build()
       }
 
+      def groupingSets(groupingSets: Seq[Seq[Expression]], groupingExprs: Expression*)(
+          aggregateExprs: Expression*): Relation = {
+        val agg = Aggregate.newBuilder()
+        agg.setInput(logicalPlan)
+        agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS)
+        for (groupingSet <- groupingSets) {
+          val groupingSetMsg = Aggregate.GroupingSets.newBuilder()
+          for (groupCol <- groupingSet) {
+            groupingSetMsg.addGroupingSet(groupCol)
+          }
+          agg.addGroupingSets(groupingSetMsg)
+        }
+        for (groupingExpr <- groupingExprs) {
+          agg.addGroupingExpressions(groupingExpr)
+        }
+        for (aggregateExpr <- aggregateExprs) {
+          agg.addAggregateExpressions(aggregateExpr)
+        }
+        Relation.newBuilder().setAggregate(agg.build()).build()
+      }
+
       def except(otherPlan: Relation, isAll: Boolean): Relation = {
         Relation
           .newBuilder()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4a0aa7e5589..95c5acc803d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2445,6 +2445,17 @@ class SparkConnectPlanner(
           aggregates = aggExprs,
           child = input)
 
+      case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
+        val groupingSetsExprs = rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets =>
+          getGroupingSets.getGroupingSetList.asScala.toSeq.map(transformExpression)
+        }
+        logical.Aggregate(
+          groupingExpressions = Seq(
+            GroupingSets(
+              groupingSets = groupingSetsExprs,
+              userGivenGroupByExprs = groupingExprs)),
+          aggregateExpressions = aliasedAgg,
+          child = input)
       case other => throw InvalidPlanInput(s"Unknown Group Type $other")
     }
   }
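
In the translated plan each proto grouping set becomes one element of Catalyst's `GroupingSets` expression, with the top-level grouping expressions kept as the user-given group-by columns. In other words, `df.groupingSets([["id", "name"], ["id"]], "id", "name")` behaves like SQL `GROUP BY id, name GROUPING SETS ((id, name), (id))`. An illustrative round trip through SQL (not part of this patch, assuming a DataFrame `df` with columns `id` and `name`):

```python
# Same aggregation expressed in SQL over a temporary view of `df`.
df.createOrReplaceTempView("t")
spark.sql("""
    SELECT id, name, count(1) AS cnt
    FROM t
    GROUP BY id, name GROUPING SETS ((id, name), (id))
""").show()
```
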
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index c54aa496c66..0b27ccdbef8 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -307,6 +307,18 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan2, sparkPlan2)
   }
 
+  test("GroupingSets expressions") {
+    val connectPlan1 =
+      connectTestRelation.groupingSets(Seq(Seq("id".protoAttr), Seq.empty), "id".protoAttr)(
+        proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build())
+          .as("agg1"))
+    val sparkPlan1 =
+      sparkTestRelation
+        .groupingSets(Seq(Seq(Column("id")), Seq.empty), Column("id"))
+        .agg(min(lit(1)).as("agg1"))
+    comparePlans(connectPlan1, sparkPlan1)
+  }
+
   test("Test as(alias: String)") {
     val connectPlan = connectTestRelation.as("target_table")
     val sparkPlan = sparkTestRelation.as("target_table")
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c7b51205363..b3bec44428b 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -550,6 +550,45 @@ class DataFrame:
 
     cube.__doc__ = PySparkDataFrame.cube.__doc__
 
+    def groupingSets(
+        self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
+    ) -> "GroupedData":
+        gsets: List[List[Column]] = []
+        for grouping_set in groupingSets:
+            gset: List[Column] = []
+            for c in grouping_set:
+                if isinstance(c, Column):
+                    gset.append(c)
+                elif isinstance(c, str):
+                    gset.append(self[c])
+                else:
+                    raise PySparkTypeError(
+                        error_class="NOT_COLUMN_OR_STR",
+                        message_parameters={
+                            "arg_name": "groupingSets",
+                            "arg_type": type(c).__name__,
+                        },
+                    )
+            gsets.append(gset)
+
+        gcols: List[Column] = []
+        for c in cols:
+            if isinstance(c, Column):
+                gcols.append(c)
+            elif isinstance(c, str):
+                gcols.append(self[c])
+            else:
+                raise PySparkTypeError(
+                    error_class="NOT_COLUMN_OR_STR",
+                    message_parameters={"arg_name": "cols", "arg_type": type(c).__name__},
+                )
+
+        return GroupedData(
+            df=self, group_type="grouping_sets", grouping_cols=gcols, grouping_sets=gsets
+        )
+
+    groupingSets.__doc__ = PySparkDataFrame.groupingSets.__doc__
+
     @overload
     def head(self) -> Optional[Row]:
         ...
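
As in the non-Connect implementation, both `Column` objects and column-name strings are accepted in the grouping sets and in the trailing group-by columns; any other element type is rejected with `PySparkTypeError` (illustrative, assuming a DataFrame `df` with columns `id` and `name`):

```python
from pyspark.errors import PySparkTypeError

# Columns and strings may be mixed freely.
gd = df.groupingSets([[df.id, "name"], ["id"]], df.id)

# Anything else raises NOT_COLUMN_OR_STR.
try:
    df.groupingSets([[42]], "id")
except PySparkTypeError as e:
    print(e.getErrorClass())  # NOT_COLUMN_OR_STR
```
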
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 7b71a43c112..481b7981a15 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -63,13 +63,20 @@ class GroupedData:
         grouping_cols: Sequence["Column"],
         pivot_col: Optional["Column"] = None,
         pivot_values: Optional[Sequence["LiteralType"]] = None,
+        grouping_sets: Optional[Sequence[Sequence["Column"]]] = None,
     ) -> None:
         from pyspark.sql.connect.dataframe import DataFrame
 
         assert isinstance(df, DataFrame)
         self._df = df
 
-        assert isinstance(group_type, str) and group_type in ["groupby", "rollup", "cube", "pivot"]
+        assert isinstance(group_type, str) and group_type in [
+            "groupby",
+            "rollup",
+            "cube",
+            "pivot",
+            "grouping_sets",
+        ]
         self._group_type = group_type
 
         assert isinstance(grouping_cols, list) and all(isinstance(g, Column) for g in grouping_cols)
@@ -83,6 +90,11 @@ class GroupedData:
             self._pivot_col = pivot_col
             self._pivot_values = pivot_values
 
+        self._grouping_sets: Optional[Sequence[Sequence["Column"]]] = None
+        if group_type == "grouping_sets":
+            assert grouping_sets is None or isinstance(grouping_sets, list)
+            self._grouping_sets = grouping_sets
+
     def __repr__(self) -> str:
         # the expressions are not resolved here,
         # so the string representation can be different from vanilla PySpark.
@@ -130,6 +142,7 @@ class GroupedData:
                 aggregate_cols=aggregate_cols,
                 pivot_col=self._pivot_col,
                 pivot_values=self._pivot_values,
+                grouping_sets=self._grouping_sets,
             ),
             session=self._df._session,
         )
@@ -171,6 +184,7 @@ class GroupedData:
+                aggregate_cols=[_invoke_function(function, col(c)) for c in agg_cols],
                 pivot_col=self._pivot_col,
                 pivot_values=self._pivot_values,
+                grouping_sets=self._grouping_sets,
             ),
             session=self._df._session,
         )
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 607d1429a9e..7d63f8714a9 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -778,10 +778,17 @@ class Aggregate(LogicalPlan):
         aggregate_cols: Sequence[Column],
         pivot_col: Optional[Column],
         pivot_values: Optional[Sequence[Any]],
+        grouping_sets: Optional[Sequence[Sequence[Column]]],
     ) -> None:
         super().__init__(child)
 
-        assert isinstance(group_type, str) and group_type in ["groupby", "rollup", "cube", "pivot"]
+        assert isinstance(group_type, str) and group_type in [
+            "groupby",
+            "rollup",
+            "cube",
+            "pivot",
+            "grouping_sets",
+        ]
         self._group_type = group_type
 
         assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
@@ -795,12 +802,16 @@ class Aggregate(LogicalPlan):
         if group_type == "pivot":
             assert pivot_col is not None and isinstance(pivot_col, Column)
             assert pivot_values is None or isinstance(pivot_values, list)
+        elif group_type == "grouping_sets":
+            assert grouping_sets is None or isinstance(grouping_sets, list)
         else:
             assert pivot_col is None
             assert pivot_values is None
+            assert grouping_sets is None
 
         self._pivot_col = pivot_col
         self._pivot_values = pivot_values
+        self._grouping_sets = grouping_sets
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         from pyspark.sql.connect.functions import lit
@@ -829,7 +840,15 @@ class Aggregate(LogicalPlan):
                 plan.aggregate.pivot.values.extend(
                     [lit(v).to_plan(session).literal for v in self._pivot_values]
                 )
-
+        elif self._group_type == "grouping_sets":
+            plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS
+            assert self._grouping_sets is not None
+            for grouping_set in self._grouping_sets:
+                plan.aggregate.grouping_sets.append(
+                    proto.Aggregate.GroupingSets(
+                        grouping_set=[c.to_plan(session) for c in grouping_set]
+                    )
+                )
         return plan
 
 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index fc70cdea402..f79ee786afb 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\x9a\x19\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\x9a\x19\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -104,101 +104,103 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _TAIL._serialized_start = 6182
     _TAIL._serialized_end = 6257
     _AGGREGATE._serialized_start = 6260
-    _AGGREGATE._serialized_end = 6842
-    _AGGREGATE_PIVOT._serialized_start = 6599
-    _AGGREGATE_PIVOT._serialized_end = 6710
-    _AGGREGATE_GROUPTYPE._serialized_start = 6713
-    _AGGREGATE_GROUPTYPE._serialized_end = 6842
-    _SORT._serialized_start = 6845
-    _SORT._serialized_end = 7005
-    _DROP._serialized_start = 7008
-    _DROP._serialized_end = 7149
-    _DEDUPLICATE._serialized_start = 7152
-    _DEDUPLICATE._serialized_end = 7392
-    _LOCALRELATION._serialized_start = 7394
-    _LOCALRELATION._serialized_end = 7483
-    _CACHEDLOCALRELATION._serialized_start = 7485
-    _CACHEDLOCALRELATION._serialized_end = 7557
-    _CACHEDREMOTERELATION._serialized_start = 7559
-    _CACHEDREMOTERELATION._serialized_end = 7614
-    _SAMPLE._serialized_start = 7617
-    _SAMPLE._serialized_end = 7890
-    _RANGE._serialized_start = 7893
-    _RANGE._serialized_end = 8038
-    _SUBQUERYALIAS._serialized_start = 8040
-    _SUBQUERYALIAS._serialized_end = 8154
-    _REPARTITION._serialized_start = 8157
-    _REPARTITION._serialized_end = 8299
-    _SHOWSTRING._serialized_start = 8302
-    _SHOWSTRING._serialized_end = 8444
-    _HTMLSTRING._serialized_start = 8446
-    _HTMLSTRING._serialized_end = 8560
-    _STATSUMMARY._serialized_start = 8562
-    _STATSUMMARY._serialized_end = 8654
-    _STATDESCRIBE._serialized_start = 8656
-    _STATDESCRIBE._serialized_end = 8737
-    _STATCROSSTAB._serialized_start = 8739
-    _STATCROSSTAB._serialized_end = 8840
-    _STATCOV._serialized_start = 8842
-    _STATCOV._serialized_end = 8938
-    _STATCORR._serialized_start = 8941
-    _STATCORR._serialized_end = 9078
-    _STATAPPROXQUANTILE._serialized_start = 9081
-    _STATAPPROXQUANTILE._serialized_end = 9245
-    _STATFREQITEMS._serialized_start = 9247
-    _STATFREQITEMS._serialized_end = 9372
-    _STATSAMPLEBY._serialized_start = 9375
-    _STATSAMPLEBY._serialized_end = 9684
-    _STATSAMPLEBY_FRACTION._serialized_start = 9576
-    _STATSAMPLEBY_FRACTION._serialized_end = 9675
-    _NAFILL._serialized_start = 9687
-    _NAFILL._serialized_end = 9821
-    _NADROP._serialized_start = 9824
-    _NADROP._serialized_end = 9958
-    _NAREPLACE._serialized_start = 9961
-    _NAREPLACE._serialized_end = 10257
-    _NAREPLACE_REPLACEMENT._serialized_start = 10116
-    _NAREPLACE_REPLACEMENT._serialized_end = 10257
-    _TODF._serialized_start = 10259
-    _TODF._serialized_end = 10347
-    _WITHCOLUMNSRENAMED._serialized_start = 10350
-    _WITHCOLUMNSRENAMED._serialized_end = 10589
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10522
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10589
-    _WITHCOLUMNS._serialized_start = 10591
-    _WITHCOLUMNS._serialized_end = 10710
-    _WITHWATERMARK._serialized_start = 10713
-    _WITHWATERMARK._serialized_end = 10847
-    _HINT._serialized_start = 10850
-    _HINT._serialized_end = 10982
-    _UNPIVOT._serialized_start = 10985
-    _UNPIVOT._serialized_end = 11312
-    _UNPIVOT_VALUES._serialized_start = 11242
-    _UNPIVOT_VALUES._serialized_end = 11301
-    _TOSCHEMA._serialized_start = 11314
-    _TOSCHEMA._serialized_end = 11420
-    _REPARTITIONBYEXPRESSION._serialized_start = 11423
-    _REPARTITIONBYEXPRESSION._serialized_end = 11626
-    _MAPPARTITIONS._serialized_start = 11629
-    _MAPPARTITIONS._serialized_end = 11810
-    _GROUPMAP._serialized_start = 11813
-    _GROUPMAP._serialized_end = 12448
-    _COGROUPMAP._serialized_start = 12451
-    _COGROUPMAP._serialized_end = 12977
-    _APPLYINPANDASWITHSTATE._serialized_start = 12980
-    _APPLYINPANDASWITHSTATE._serialized_end = 13337
-    _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13340
-    _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 13584
-    _PYTHONUDTF._serialized_start = 13587
-    _PYTHONUDTF._serialized_end = 13764
-    _COLLECTMETRICS._serialized_start = 13767
-    _COLLECTMETRICS._serialized_end = 13903
-    _PARSE._serialized_start = 13906
-    _PARSE._serialized_end = 14294
+    _AGGREGATE._serialized_end = 7026
+    _AGGREGATE_PIVOT._serialized_start = 6675
+    _AGGREGATE_PIVOT._serialized_end = 6786
+    _AGGREGATE_GROUPINGSETS._serialized_start = 6788
+    _AGGREGATE_GROUPINGSETS._serialized_end = 6864
+    _AGGREGATE_GROUPTYPE._serialized_start = 6867
+    _AGGREGATE_GROUPTYPE._serialized_end = 7026
+    _SORT._serialized_start = 7029
+    _SORT._serialized_end = 7189
+    _DROP._serialized_start = 7192
+    _DROP._serialized_end = 7333
+    _DEDUPLICATE._serialized_start = 7336
+    _DEDUPLICATE._serialized_end = 7576
+    _LOCALRELATION._serialized_start = 7578
+    _LOCALRELATION._serialized_end = 7667
+    _CACHEDLOCALRELATION._serialized_start = 7669
+    _CACHEDLOCALRELATION._serialized_end = 7741
+    _CACHEDREMOTERELATION._serialized_start = 7743
+    _CACHEDREMOTERELATION._serialized_end = 7798
+    _SAMPLE._serialized_start = 7801
+    _SAMPLE._serialized_end = 8074
+    _RANGE._serialized_start = 8077
+    _RANGE._serialized_end = 8222
+    _SUBQUERYALIAS._serialized_start = 8224
+    _SUBQUERYALIAS._serialized_end = 8338
+    _REPARTITION._serialized_start = 8341
+    _REPARTITION._serialized_end = 8483
+    _SHOWSTRING._serialized_start = 8486
+    _SHOWSTRING._serialized_end = 8628
+    _HTMLSTRING._serialized_start = 8630
+    _HTMLSTRING._serialized_end = 8744
+    _STATSUMMARY._serialized_start = 8746
+    _STATSUMMARY._serialized_end = 8838
+    _STATDESCRIBE._serialized_start = 8840
+    _STATDESCRIBE._serialized_end = 8921
+    _STATCROSSTAB._serialized_start = 8923
+    _STATCROSSTAB._serialized_end = 9024
+    _STATCOV._serialized_start = 9026
+    _STATCOV._serialized_end = 9122
+    _STATCORR._serialized_start = 9125
+    _STATCORR._serialized_end = 9262
+    _STATAPPROXQUANTILE._serialized_start = 9265
+    _STATAPPROXQUANTILE._serialized_end = 9429
+    _STATFREQITEMS._serialized_start = 9431
+    _STATFREQITEMS._serialized_end = 9556
+    _STATSAMPLEBY._serialized_start = 9559
+    _STATSAMPLEBY._serialized_end = 9868
+    _STATSAMPLEBY_FRACTION._serialized_start = 9760
+    _STATSAMPLEBY_FRACTION._serialized_end = 9859
+    _NAFILL._serialized_start = 9871
+    _NAFILL._serialized_end = 10005
+    _NADROP._serialized_start = 10008
+    _NADROP._serialized_end = 10142
+    _NAREPLACE._serialized_start = 10145
+    _NAREPLACE._serialized_end = 10441
+    _NAREPLACE_REPLACEMENT._serialized_start = 10300
+    _NAREPLACE_REPLACEMENT._serialized_end = 10441
+    _TODF._serialized_start = 10443
+    _TODF._serialized_end = 10531
+    _WITHCOLUMNSRENAMED._serialized_start = 10534
+    _WITHCOLUMNSRENAMED._serialized_end = 10773
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10706
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10773
+    _WITHCOLUMNS._serialized_start = 10775
+    _WITHCOLUMNS._serialized_end = 10894
+    _WITHWATERMARK._serialized_start = 10897
+    _WITHWATERMARK._serialized_end = 11031
+    _HINT._serialized_start = 11034
+    _HINT._serialized_end = 11166
+    _UNPIVOT._serialized_start = 11169
+    _UNPIVOT._serialized_end = 11496
+    _UNPIVOT_VALUES._serialized_start = 11426
+    _UNPIVOT_VALUES._serialized_end = 11485
+    _TOSCHEMA._serialized_start = 11498
+    _TOSCHEMA._serialized_end = 11604
+    _REPARTITIONBYEXPRESSION._serialized_start = 11607
+    _REPARTITIONBYEXPRESSION._serialized_end = 11810
+    _MAPPARTITIONS._serialized_start = 11813
+    _MAPPARTITIONS._serialized_end = 11994
+    _GROUPMAP._serialized_start = 11997
+    _GROUPMAP._serialized_end = 12632
+    _COGROUPMAP._serialized_start = 12635
+    _COGROUPMAP._serialized_end = 13161
+    _APPLYINPANDASWITHSTATE._serialized_start = 13164
+    _APPLYINPANDASWITHSTATE._serialized_end = 13521
+    _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13524
+    _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 13768
+    _PYTHONUDTF._serialized_start = 13771
+    _PYTHONUDTF._serialized_end = 13948
+    _COLLECTMETRICS._serialized_start = 13951
+    _COLLECTMETRICS._serialized_end = 14087
+    _PARSE._serialized_start = 14090
+    _PARSE._serialized_end = 14478
     _PARSE_OPTIONSENTRY._serialized_start = 4291
     _PARSE_OPTIONSENTRY._serialized_end = 4349
-    _PARSE_PARSEFORMAT._serialized_start = 14195
-    _PARSE_PARSEFORMAT._serialized_end = 14283
-    _ASOFJOIN._serialized_start = 14297
-    _ASOFJOIN._serialized_end = 14772
+    _PARSE_PARSEFORMAT._serialized_start = 14379
+    _PARSE_PARSEFORMAT._serialized_end = 14467
+    _ASOFJOIN._serialized_start = 14481
+    _ASOFJOIN._serialized_end = 14956
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 5bca4f21b2e..f8b7a2ad1cd 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1380,6 +1380,7 @@ class Aggregate(google.protobuf.message.Message):
         GROUP_TYPE_ROLLUP: Aggregate._GroupType.ValueType  # 2
         GROUP_TYPE_CUBE: Aggregate._GroupType.ValueType  # 3
         GROUP_TYPE_PIVOT: Aggregate._GroupType.ValueType  # 4
+        GROUP_TYPE_GROUPING_SETS: Aggregate._GroupType.ValueType  # 5
 
     class GroupType(_GroupType, metaclass=_GroupTypeEnumTypeWrapper): ...
     GROUP_TYPE_UNSPECIFIED: Aggregate.GroupType.ValueType  # 0
@@ -1387,6 +1388,7 @@ class Aggregate(google.protobuf.message.Message):
     GROUP_TYPE_ROLLUP: Aggregate.GroupType.ValueType  # 2
     GROUP_TYPE_CUBE: Aggregate.GroupType.ValueType  # 3
     GROUP_TYPE_PIVOT: Aggregate.GroupType.ValueType  # 4
+    GROUP_TYPE_GROUPING_SETS: Aggregate.GroupType.ValueType  # 5
 
     class Pivot(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -1423,11 +1425,35 @@ class Aggregate(google.protobuf.message.Message):
             self, field_name: typing_extensions.Literal["col", b"col", 
"values", b"values"]
         ) -> None: ...
 
+    class GroupingSets(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        GROUPING_SET_FIELD_NUMBER: builtins.int
+        @property
+        def grouping_set(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+            pyspark.sql.connect.proto.expressions_pb2.Expression
+        ]:
+            """(Required) Individual grouping set"""
+        def __init__(
+            self,
+            *,
+            grouping_set: collections.abc.Iterable[
+                pyspark.sql.connect.proto.expressions_pb2.Expression
+            ]
+            | None = ...,
+        ) -> None: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["grouping_set", b"grouping_set"]
+        ) -> None: ...
+
     INPUT_FIELD_NUMBER: builtins.int
     GROUP_TYPE_FIELD_NUMBER: builtins.int
     GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
     AGGREGATE_EXPRESSIONS_FIELD_NUMBER: builtins.int
     PIVOT_FIELD_NUMBER: builtins.int
+    GROUPING_SETS_FIELD_NUMBER: builtins.int
     @property
     def input(self) -> global___Relation:
         """(Required) Input relation for a RelationalGroupedDataset."""
@@ -1450,6 +1476,13 @@ class Aggregate(google.protobuf.message.Message):
     @property
     def pivot(self) -> global___Aggregate.Pivot:
         """(Optional) Pivots a column of the current `DataFrame` and performs 
the specified aggregation."""
+    @property
+    def grouping_sets(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        global___Aggregate.GroupingSets
+    ]:
+        """(Optional) List of values that will be translated to columns in the 
output DataFrame."""
     def __init__(
         self,
         *,
@@ -1464,6 +1497,7 @@ class Aggregate(google.protobuf.message.Message):
         ]
         | None = ...,
         pivot: global___Aggregate.Pivot | None = ...,
+        grouping_sets: collections.abc.Iterable[global___Aggregate.GroupingSets] | None = ...,
     ) -> None: ...
     def HasField(
         self, field_name: typing_extensions.Literal["input", b"input", 
"pivot", b"pivot"]
@@ -1477,6 +1511,8 @@ class Aggregate(google.protobuf.message.Message):
             b"group_type",
             "grouping_expressions",
             b"grouping_expressions",
+            "grouping_sets",
+            b"grouping_sets",
             "input",
             b"input",
             "pivot",
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 383a5566ded..82087adc82f 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -4204,7 +4204,6 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
         return GroupedData(jgd, self)
 
-    # TODO(SPARK-46048): Add it to Python Spark Connect client.
     def groupingSets(
         self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
     ) -> "GroupedData":

