This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d14410c6777 [SPARK-46048][PYTHON][CONNECT] Support DataFrame.groupingSets in Python Spark Connect d14410c6777 is described below commit d14410c6777e7de7f61e1957fab749da2793f4b8 Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Thu Nov 23 16:38:52 2023 +0900 [SPARK-46048][PYTHON][CONNECT] Support DataFrame.groupingSets in Python Spark Connect ### What changes were proposed in this pull request? This PR adds `DataFrame.groupingSets` in Python Spark Connect. ### Why are the changes needed? For feature parity with non-Spark Connect. ### Does this PR introduce _any_ user-facing change? Yes, it adds the new API `DataFrame.groupingSets` in Python Spark Connect. ### How was this patch tested? Unit tests were added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43967 from HyukjinKwon/SPARK-46048. Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../main/protobuf/spark/connect/relations.proto | 9 + .../org/apache/spark/sql/connect/dsl/package.scala | 21 +++ .../sql/connect/planner/SparkConnectPlanner.scala | 11 ++ .../connect/planner/SparkConnectProtoSuite.scala | 12 ++ python/pyspark/sql/connect/dataframe.py | 39 +++++ python/pyspark/sql/connect/group.py | 16 +- python/pyspark/sql/connect/plan.py | 23 ++- python/pyspark/sql/connect/proto/relations_pb2.py | 194 +++++++++++---------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 36 ++++ python/pyspark/sql/dataframe.py | 1 - 10 files changed, 262 insertions(+), 100 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index deb33978386..43f692671df 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -327,12 +327,16 @@ 
message Aggregate { // (Optional) Pivots a column of the current `DataFrame` and performs the specified aggregation. Pivot pivot = 5; + // (Optional) List of values that will be translated to columns in the output DataFrame. + repeated GroupingSets grouping_sets = 6; + enum GroupType { GROUP_TYPE_UNSPECIFIED = 0; GROUP_TYPE_GROUPBY = 1; GROUP_TYPE_ROLLUP = 2; GROUP_TYPE_CUBE = 3; GROUP_TYPE_PIVOT = 4; + GROUP_TYPE_GROUPING_SETS = 5; } message Pivot { @@ -345,6 +349,11 @@ message Aggregate { // the distinct values of the column. repeated Expression.Literal values = 2; } + + message GroupingSets { + // (Required) Individual grouping set + repeated Expression grouping_set = 1; + } } // Relation of type [[Sort]]. diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 5fd1a035385..18c71ae4ace 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -800,6 +800,27 @@ package object dsl { Relation.newBuilder().setAggregate(agg.build()).build() } + def groupingSets(groupingSets: Seq[Seq[Expression]], groupingExprs: Expression*)( + aggregateExprs: Expression*): Relation = { + val agg = Aggregate.newBuilder() + agg.setInput(logicalPlan) + agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS) + for (groupingSet <- groupingSets) { + val groupingSetMsg = Aggregate.GroupingSets.newBuilder() + for (groupCol <- groupingSet) { + groupingSetMsg.addGroupingSet(groupCol) + } + agg.addGroupingSets(groupingSetMsg) + } + for (groupingExpr <- groupingExprs) { + agg.addGroupingExpressions(groupingExpr) + } + for (aggregateExpr <- aggregateExprs) { + agg.addAggregateExpressions(aggregateExpr) + } + Relation.newBuilder().setAggregate(agg.build()).build() + } + def except(otherPlan: Relation, isAll: 
Boolean): Relation = { Relation .newBuilder() diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 4a0aa7e5589..95c5acc803d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2445,6 +2445,17 @@ class SparkConnectPlanner( aggregates = aggExprs, child = input) + case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS => + val groupingSetsExprs = rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets => + getGroupingSets.getGroupingSetList.asScala.toSeq.map(transformExpression) + } + logical.Aggregate( + groupingExpressions = Seq( + GroupingSets( + groupingSets = groupingSetsExprs, + userGivenGroupByExprs = groupingExprs)), + aggregateExpressions = aliasedAgg, + child = input) case other => throw InvalidPlanInput(s"Unknown Group Type $other") } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index c54aa496c66..0b27ccdbef8 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -307,6 +307,18 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { comparePlans(connectPlan2, sparkPlan2) } + test("GroupingSets expressions") { + val connectPlan1 = + connectTestRelation.groupingSets(Seq(Seq("id".protoAttr), Seq.empty), "id".protoAttr)( + proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build()) + .as("agg1")) + val sparkPlan1 = + 
sparkTestRelation + .groupingSets(Seq(Seq(Column("id")), Seq.empty), Column("id")) + .agg(min(lit(1)).as("agg1")) + comparePlans(connectPlan1, sparkPlan1) + } + test("Test as(alias: String)") { val connectPlan = connectTestRelation.as("target_table") val sparkPlan = sparkTestRelation.as("target_table") diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index c7b51205363..b3bec44428b 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -550,6 +550,45 @@ class DataFrame: cube.__doc__ = PySparkDataFrame.cube.__doc__ + def groupingSets( + self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName" + ) -> "GroupedData": + gsets: List[List[Column]] = [] + for grouping_set in groupingSets: + gset: List[Column] = [] + for c in grouping_set: + if isinstance(c, Column): + gset.append(c) + elif isinstance(c, str): + gset.append(self[c]) + else: + raise PySparkTypeError( + error_class="NOT_COLUMN_OR_STR", + message_parameters={ + "arg_name": "groupingSets", + "arg_type": type(c).__name__, + }, + ) + gsets.append(gset) + + gcols: List[Column] = [] + for c in cols: + if isinstance(c, Column): + gcols.append(c) + elif isinstance(c, str): + gcols.append(self[c]) + else: + raise PySparkTypeError( + error_class="NOT_COLUMN_OR_STR", + message_parameters={"arg_name": "cols", "arg_type": type(c).__name__}, + ) + + return GroupedData( + df=self, group_type="grouping_sets", grouping_cols=gcols, grouping_sets=gsets + ) + + groupingSets.__doc__ = PySparkDataFrame.groupingSets.__doc__ + @overload def head(self) -> Optional[Row]: ... 
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index 7b71a43c112..481b7981a15 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -63,13 +63,20 @@ class GroupedData: grouping_cols: Sequence["Column"], pivot_col: Optional["Column"] = None, pivot_values: Optional[Sequence["LiteralType"]] = None, + grouping_sets: Optional[Sequence[Sequence["Column"]]] = None, ) -> None: from pyspark.sql.connect.dataframe import DataFrame assert isinstance(df, DataFrame) self._df = df - assert isinstance(group_type, str) and group_type in ["groupby", "rollup", "cube", "pivot"] + assert isinstance(group_type, str) and group_type in [ + "groupby", + "rollup", + "cube", + "pivot", + "grouping_sets", + ] self._group_type = group_type assert isinstance(grouping_cols, list) and all(isinstance(g, Column) for g in grouping_cols) @@ -83,6 +90,11 @@ class GroupedData: self._pivot_col = pivot_col self._pivot_values = pivot_values + self._grouping_sets: Optional[Sequence[Sequence["Column"]]] = None + if group_type == "grouping_sets": + assert grouping_sets is None or isinstance(grouping_sets, list) + self._grouping_sets = grouping_sets + def __repr__(self) -> str: # the expressions are not resolved here, # so the string representation can be different from vanilla PySpark. 
@@ -130,6 +142,7 @@ class GroupedData: aggregate_cols=aggregate_cols, pivot_col=self._pivot_col, pivot_values=self._pivot_values, + grouping_sets=self._grouping_sets, ), session=self._df._session, ) @@ -171,6 +184,7 @@ class GroupedData: aggregate_cols=[_invoke_function(function, col(c)) for c in agg_cols], pivot_col=self._pivot_col, pivot_values=self._pivot_values, + grouping_sets=self._grouping_sets, ), session=self._df._session, ) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 607d1429a9e..7d63f8714a9 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -778,10 +778,17 @@ class Aggregate(LogicalPlan): aggregate_cols: Sequence[Column], pivot_col: Optional[Column], pivot_values: Optional[Sequence[Any]], + grouping_sets: Optional[Sequence[Sequence[Column]]], ) -> None: super().__init__(child) - assert isinstance(group_type, str) and group_type in ["groupby", "rollup", "cube", "pivot"] + assert isinstance(group_type, str) and group_type in [ + "groupby", + "rollup", + "cube", + "pivot", + "grouping_sets", + ] self._group_type = group_type assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols) @@ -795,12 +802,16 @@ class Aggregate(LogicalPlan): if group_type == "pivot": assert pivot_col is not None and isinstance(pivot_col, Column) assert pivot_values is None or isinstance(pivot_values, list) + elif group_type == "grouping_sets": + assert grouping_sets is None or isinstance(grouping_sets, list) else: assert pivot_col is None assert pivot_values is None + assert grouping_sets is None self._pivot_col = pivot_col self._pivot_values = pivot_values + self._grouping_sets = grouping_sets def plan(self, session: "SparkConnectClient") -> proto.Relation: from pyspark.sql.connect.functions import lit @@ -829,7 +840,15 @@ class Aggregate(LogicalPlan): plan.aggregate.pivot.values.extend( [lit(v).to_plan(session).literal for v in self._pivot_values] ) - + elif 
self._group_type == "grouping_sets": + plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS + assert self._grouping_sets is not None + for grouping_set in self._grouping_sets: + plan.aggregate.grouping_sets.append( + proto.Aggregate.GroupingSets( + grouping_set=[c.to_plan(session) for c in grouping_set] + ) + ) return plan diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index fc70cdea402..f79ee786afb 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\x9a\x19\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\x9a\x19\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] 
) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -104,101 +104,103 @@ if _descriptor._USE_C_DESCRIPTORS == False: _TAIL._serialized_start = 6182 _TAIL._serialized_end = 6257 _AGGREGATE._serialized_start = 6260 - _AGGREGATE._serialized_end = 6842 - _AGGREGATE_PIVOT._serialized_start = 6599 - _AGGREGATE_PIVOT._serialized_end = 6710 - _AGGREGATE_GROUPTYPE._serialized_start = 6713 - _AGGREGATE_GROUPTYPE._serialized_end = 6842 - _SORT._serialized_start = 6845 - _SORT._serialized_end = 7005 - _DROP._serialized_start = 7008 - _DROP._serialized_end = 7149 - _DEDUPLICATE._serialized_start = 7152 - _DEDUPLICATE._serialized_end = 7392 - _LOCALRELATION._serialized_start = 7394 - _LOCALRELATION._serialized_end = 7483 - _CACHEDLOCALRELATION._serialized_start = 7485 - _CACHEDLOCALRELATION._serialized_end = 7557 - _CACHEDREMOTERELATION._serialized_start = 7559 - _CACHEDREMOTERELATION._serialized_end = 7614 - _SAMPLE._serialized_start = 7617 - _SAMPLE._serialized_end = 7890 - _RANGE._serialized_start = 7893 - _RANGE._serialized_end = 8038 - _SUBQUERYALIAS._serialized_start = 8040 - _SUBQUERYALIAS._serialized_end = 8154 - _REPARTITION._serialized_start = 8157 - _REPARTITION._serialized_end = 8299 - _SHOWSTRING._serialized_start = 8302 - _SHOWSTRING._serialized_end = 8444 - _HTMLSTRING._serialized_start = 8446 - _HTMLSTRING._serialized_end = 8560 - _STATSUMMARY._serialized_start = 8562 - _STATSUMMARY._serialized_end = 8654 - _STATDESCRIBE._serialized_start = 8656 - _STATDESCRIBE._serialized_end = 8737 - _STATCROSSTAB._serialized_start = 8739 - _STATCROSSTAB._serialized_end = 8840 - _STATCOV._serialized_start = 8842 - _STATCOV._serialized_end = 8938 - _STATCORR._serialized_start = 8941 - _STATCORR._serialized_end = 9078 - _STATAPPROXQUANTILE._serialized_start = 9081 - _STATAPPROXQUANTILE._serialized_end = 9245 - _STATFREQITEMS._serialized_start = 9247 - _STATFREQITEMS._serialized_end = 9372 - _STATSAMPLEBY._serialized_start = 9375 - _STATSAMPLEBY._serialized_end = 
9684 - _STATSAMPLEBY_FRACTION._serialized_start = 9576 - _STATSAMPLEBY_FRACTION._serialized_end = 9675 - _NAFILL._serialized_start = 9687 - _NAFILL._serialized_end = 9821 - _NADROP._serialized_start = 9824 - _NADROP._serialized_end = 9958 - _NAREPLACE._serialized_start = 9961 - _NAREPLACE._serialized_end = 10257 - _NAREPLACE_REPLACEMENT._serialized_start = 10116 - _NAREPLACE_REPLACEMENT._serialized_end = 10257 - _TODF._serialized_start = 10259 - _TODF._serialized_end = 10347 - _WITHCOLUMNSRENAMED._serialized_start = 10350 - _WITHCOLUMNSRENAMED._serialized_end = 10589 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10522 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10589 - _WITHCOLUMNS._serialized_start = 10591 - _WITHCOLUMNS._serialized_end = 10710 - _WITHWATERMARK._serialized_start = 10713 - _WITHWATERMARK._serialized_end = 10847 - _HINT._serialized_start = 10850 - _HINT._serialized_end = 10982 - _UNPIVOT._serialized_start = 10985 - _UNPIVOT._serialized_end = 11312 - _UNPIVOT_VALUES._serialized_start = 11242 - _UNPIVOT_VALUES._serialized_end = 11301 - _TOSCHEMA._serialized_start = 11314 - _TOSCHEMA._serialized_end = 11420 - _REPARTITIONBYEXPRESSION._serialized_start = 11423 - _REPARTITIONBYEXPRESSION._serialized_end = 11626 - _MAPPARTITIONS._serialized_start = 11629 - _MAPPARTITIONS._serialized_end = 11810 - _GROUPMAP._serialized_start = 11813 - _GROUPMAP._serialized_end = 12448 - _COGROUPMAP._serialized_start = 12451 - _COGROUPMAP._serialized_end = 12977 - _APPLYINPANDASWITHSTATE._serialized_start = 12980 - _APPLYINPANDASWITHSTATE._serialized_end = 13337 - _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13340 - _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 13584 - _PYTHONUDTF._serialized_start = 13587 - _PYTHONUDTF._serialized_end = 13764 - _COLLECTMETRICS._serialized_start = 13767 - _COLLECTMETRICS._serialized_end = 13903 - _PARSE._serialized_start = 13906 - _PARSE._serialized_end = 14294 + 
_AGGREGATE._serialized_end = 7026 + _AGGREGATE_PIVOT._serialized_start = 6675 + _AGGREGATE_PIVOT._serialized_end = 6786 + _AGGREGATE_GROUPINGSETS._serialized_start = 6788 + _AGGREGATE_GROUPINGSETS._serialized_end = 6864 + _AGGREGATE_GROUPTYPE._serialized_start = 6867 + _AGGREGATE_GROUPTYPE._serialized_end = 7026 + _SORT._serialized_start = 7029 + _SORT._serialized_end = 7189 + _DROP._serialized_start = 7192 + _DROP._serialized_end = 7333 + _DEDUPLICATE._serialized_start = 7336 + _DEDUPLICATE._serialized_end = 7576 + _LOCALRELATION._serialized_start = 7578 + _LOCALRELATION._serialized_end = 7667 + _CACHEDLOCALRELATION._serialized_start = 7669 + _CACHEDLOCALRELATION._serialized_end = 7741 + _CACHEDREMOTERELATION._serialized_start = 7743 + _CACHEDREMOTERELATION._serialized_end = 7798 + _SAMPLE._serialized_start = 7801 + _SAMPLE._serialized_end = 8074 + _RANGE._serialized_start = 8077 + _RANGE._serialized_end = 8222 + _SUBQUERYALIAS._serialized_start = 8224 + _SUBQUERYALIAS._serialized_end = 8338 + _REPARTITION._serialized_start = 8341 + _REPARTITION._serialized_end = 8483 + _SHOWSTRING._serialized_start = 8486 + _SHOWSTRING._serialized_end = 8628 + _HTMLSTRING._serialized_start = 8630 + _HTMLSTRING._serialized_end = 8744 + _STATSUMMARY._serialized_start = 8746 + _STATSUMMARY._serialized_end = 8838 + _STATDESCRIBE._serialized_start = 8840 + _STATDESCRIBE._serialized_end = 8921 + _STATCROSSTAB._serialized_start = 8923 + _STATCROSSTAB._serialized_end = 9024 + _STATCOV._serialized_start = 9026 + _STATCOV._serialized_end = 9122 + _STATCORR._serialized_start = 9125 + _STATCORR._serialized_end = 9262 + _STATAPPROXQUANTILE._serialized_start = 9265 + _STATAPPROXQUANTILE._serialized_end = 9429 + _STATFREQITEMS._serialized_start = 9431 + _STATFREQITEMS._serialized_end = 9556 + _STATSAMPLEBY._serialized_start = 9559 + _STATSAMPLEBY._serialized_end = 9868 + _STATSAMPLEBY_FRACTION._serialized_start = 9760 + _STATSAMPLEBY_FRACTION._serialized_end = 9859 + _NAFILL._serialized_start = 
9871 + _NAFILL._serialized_end = 10005 + _NADROP._serialized_start = 10008 + _NADROP._serialized_end = 10142 + _NAREPLACE._serialized_start = 10145 + _NAREPLACE._serialized_end = 10441 + _NAREPLACE_REPLACEMENT._serialized_start = 10300 + _NAREPLACE_REPLACEMENT._serialized_end = 10441 + _TODF._serialized_start = 10443 + _TODF._serialized_end = 10531 + _WITHCOLUMNSRENAMED._serialized_start = 10534 + _WITHCOLUMNSRENAMED._serialized_end = 10773 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10706 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10773 + _WITHCOLUMNS._serialized_start = 10775 + _WITHCOLUMNS._serialized_end = 10894 + _WITHWATERMARK._serialized_start = 10897 + _WITHWATERMARK._serialized_end = 11031 + _HINT._serialized_start = 11034 + _HINT._serialized_end = 11166 + _UNPIVOT._serialized_start = 11169 + _UNPIVOT._serialized_end = 11496 + _UNPIVOT_VALUES._serialized_start = 11426 + _UNPIVOT_VALUES._serialized_end = 11485 + _TOSCHEMA._serialized_start = 11498 + _TOSCHEMA._serialized_end = 11604 + _REPARTITIONBYEXPRESSION._serialized_start = 11607 + _REPARTITIONBYEXPRESSION._serialized_end = 11810 + _MAPPARTITIONS._serialized_start = 11813 + _MAPPARTITIONS._serialized_end = 11994 + _GROUPMAP._serialized_start = 11997 + _GROUPMAP._serialized_end = 12632 + _COGROUPMAP._serialized_start = 12635 + _COGROUPMAP._serialized_end = 13161 + _APPLYINPANDASWITHSTATE._serialized_start = 13164 + _APPLYINPANDASWITHSTATE._serialized_end = 13521 + _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13524 + _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 13768 + _PYTHONUDTF._serialized_start = 13771 + _PYTHONUDTF._serialized_end = 13948 + _COLLECTMETRICS._serialized_start = 13951 + _COLLECTMETRICS._serialized_end = 14087 + _PARSE._serialized_start = 14090 + _PARSE._serialized_end = 14478 _PARSE_OPTIONSENTRY._serialized_start = 4291 _PARSE_OPTIONSENTRY._serialized_end = 4349 - _PARSE_PARSEFORMAT._serialized_start = 14195 - 
_PARSE_PARSEFORMAT._serialized_end = 14283 - _ASOFJOIN._serialized_start = 14297 - _ASOFJOIN._serialized_end = 14772 + _PARSE_PARSEFORMAT._serialized_start = 14379 + _PARSE_PARSEFORMAT._serialized_end = 14467 + _ASOFJOIN._serialized_start = 14481 + _ASOFJOIN._serialized_end = 14956 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 5bca4f21b2e..f8b7a2ad1cd 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -1380,6 +1380,7 @@ class Aggregate(google.protobuf.message.Message): GROUP_TYPE_ROLLUP: Aggregate._GroupType.ValueType # 2 GROUP_TYPE_CUBE: Aggregate._GroupType.ValueType # 3 GROUP_TYPE_PIVOT: Aggregate._GroupType.ValueType # 4 + GROUP_TYPE_GROUPING_SETS: Aggregate._GroupType.ValueType # 5 class GroupType(_GroupType, metaclass=_GroupTypeEnumTypeWrapper): ... GROUP_TYPE_UNSPECIFIED: Aggregate.GroupType.ValueType # 0 @@ -1387,6 +1388,7 @@ class Aggregate(google.protobuf.message.Message): GROUP_TYPE_ROLLUP: Aggregate.GroupType.ValueType # 2 GROUP_TYPE_CUBE: Aggregate.GroupType.ValueType # 3 GROUP_TYPE_PIVOT: Aggregate.GroupType.ValueType # 4 + GROUP_TYPE_GROUPING_SETS: Aggregate.GroupType.ValueType # 5 class Pivot(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -1423,11 +1425,35 @@ class Aggregate(google.protobuf.message.Message): self, field_name: typing_extensions.Literal["col", b"col", "values", b"values"] ) -> None: ... 
+ class GroupingSets(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + GROUPING_SET_FIELD_NUMBER: builtins.int + @property + def grouping_set( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ]: + """(Required) Individual grouping set""" + def __init__( + self, + *, + grouping_set: collections.abc.Iterable[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ] + | None = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["grouping_set", b"grouping_set"] + ) -> None: ... + INPUT_FIELD_NUMBER: builtins.int GROUP_TYPE_FIELD_NUMBER: builtins.int GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int AGGREGATE_EXPRESSIONS_FIELD_NUMBER: builtins.int PIVOT_FIELD_NUMBER: builtins.int + GROUPING_SETS_FIELD_NUMBER: builtins.int @property def input(self) -> global___Relation: """(Required) Input relation for a RelationalGroupedDataset.""" @@ -1450,6 +1476,13 @@ class Aggregate(google.protobuf.message.Message): @property def pivot(self) -> global___Aggregate.Pivot: """(Optional) Pivots a column of the current `DataFrame` and performs the specified aggregation.""" + @property + def grouping_sets( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___Aggregate.GroupingSets + ]: + """(Optional) List of values that will be translated to columns in the output DataFrame.""" def __init__( self, *, @@ -1464,6 +1497,7 @@ class Aggregate(google.protobuf.message.Message): ] | None = ..., pivot: global___Aggregate.Pivot | None = ..., + grouping_sets: collections.abc.Iterable[global___Aggregate.GroupingSets] | None = ..., ) -> None: ... 
def HasField( self, field_name: typing_extensions.Literal["input", b"input", "pivot", b"pivot"] @@ -1477,6 +1511,8 @@ class Aggregate(google.protobuf.message.Message): b"group_type", "grouping_expressions", b"grouping_expressions", + "grouping_sets", + b"grouping_sets", "input", b"input", "pivot", diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 383a5566ded..82087adc82f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -4204,7 +4204,6 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return GroupedData(jgd, self) - # TODO(SPARK-46048): Add it to Python Spark Connect client. def groupingSets( self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName" ) -> "GroupedData": --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org