This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5211f6b140a [SPARK-46085][CONNECT] Dataset.groupingSets in Scala Spark Connect client 5211f6b140a is described below commit 5211f6b140a74bd28f7e05934508bdafdbe7f237 Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Fri Nov 24 17:52:23 2023 -0800 [SPARK-46085][CONNECT] Dataset.groupingSets in Scala Spark Connect client ### What changes were proposed in this pull request? This PR proposes to add `Dataset.groupingSets` API added from https://github.com/apache/spark/pull/43813 to Scala Spark Connect client. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? Yes, it adds a new API to Scala Spark Connect client. ### How was this patch tested? Unittest was added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43995 from HyukjinKwon/SPARK-46085. Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../main/scala/org/apache/spark/sql/Dataset.scala | 35 +++++++++++++++ .../spark/sql/RelationalGroupedDataset.scala | 8 +++- .../apache/spark/sql/PlanGenerationTestSuite.scala | 6 +++ .../explain-results/groupingSets.explain | 4 ++ .../query-tests/queries/groupingSets.json | 50 +++++++++++++++++++++ .../query-tests/queries/groupingSets.proto.bin | Bin 0 -> 106 bytes 6 files changed, 102 insertions(+), 1 deletion(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index a1e57226e53..d760c9d9769 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1532,6 +1532,41 @@ class Dataset[T] private[sql] ( proto.Aggregate.GroupType.GROUP_TYPE_CUBE) } + /** + * Create multi-dimensional aggregation for the current Dataset using the 
specified grouping + * sets, so we can run aggregation on them. See [[RelationalGroupedDataset]] for all the + * available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns group by specific grouping sets. + * ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg() + * + * // Compute the max age and average salary, group by specific grouping sets. + * ds.groupingSets(Seq(Seq($"department", $"gender"), Seq()), $"department", $"group").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 4.0.0 + */ + @scala.annotation.varargs + def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = { + val groupingSetMsgs = groupingSets.map { groupingSet => + val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder() + for (groupCol <- groupingSet) { + groupingSetMsg.addGroupingSet(groupCol.expr) + } + groupingSetMsg.build() + } + new RelationalGroupedDataset( + toDF(), + cols, + proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS, + groupingSets = Some(groupingSetMsgs)) + } + /** + * (Scala-specific) Aggregates on the entire Dataset without groups. 
* {{{ diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 5ed97e45c77..776a6231eae 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -39,7 +39,8 @@ class RelationalGroupedDataset private[sql] ( private[sql] val df: DataFrame, private[sql] val groupingExprs: Seq[Column], groupType: proto.Aggregate.GroupType, - pivot: Option[proto.Aggregate.Pivot] = None) { + pivot: Option[proto.Aggregate.Pivot] = None, + groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) { private[this] def toDF(aggExprs: Seq[Column]): DataFrame = { df.sparkSession.newDataFrame { builder => @@ -60,6 +61,11 @@ class RelationalGroupedDataset private[sql] ( builder.getAggregateBuilder .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT) .setPivot(pivot.get) + case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS => + assert(groupingSets.isDefined) + val aggBuilder = builder.getAggregateBuilder + .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS) + groupingSets.get.foreach(aggBuilder.addGroupingSets) case g => throw new UnsupportedOperationException(g.toString) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 5cc63bc45a0..c5c917ebfa9 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -3017,6 +3017,12 @@ class PlanGenerationTestSuite simple.groupBy(Column("id")).pivot("a").agg(functions.count(Column("b"))) } + test("groupingSets") 
{ + simple + .groupingSets(Seq(Seq(fn.col("a")), Seq.empty[Column]), fn.col("a")) + .agg("a" -> "max", "a" -> "count") + } + test("width_bucket") { simple.select(fn.width_bucket(fn.col("b"), fn.col("b"), fn.col("b"), fn.col("a"))) } diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/groupingSets.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/groupingSets.explain new file mode 100644 index 00000000000..1e3fe1a987e --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/groupingSets.explain @@ -0,0 +1,4 @@ +Aggregate [a#0, spark_grouping_id#0L], [a#0, max(a#0) AS max(a)#0, count(a#0) AS count(a)#0L] ++- Expand [[id#0L, a#0, b#0, a#0, 0], [id#0L, a#0, b#0, null, 1]], [id#0L, a#0, b#0, a#0, spark_grouping_id#0L] + +- Project [id#0L, a#0, b#0, a#0 AS a#0] + +- LocalRelation <empty>, [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.json b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.json new file mode 100644 index 00000000000..6e84824ec7a --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.json @@ -0,0 +1,50 @@ +{ + "common": { + "planId": "1" + }, + "aggregate": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double\u003e" + } + }, + "groupType": "GROUP_TYPE_GROUPING_SETS", + "groupingExpressions": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }], + "aggregateExpressions": [{ + "unresolvedFunction": { + "functionName": "max", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a", + "planId": "0" + } + }] + } + }, { + "unresolvedFunction": { + "functionName": "count", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a", + "planId": "0" + } + }] + } + }], + "groupingSets": [{ + "groupingSet": [{ + 
"unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }] + }, { + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin new file mode 100644 index 00000000000..ce029409670 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin differ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org