This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5211f6b140a [SPARK-46085][CONNECT] Dataset.groupingSets in Scala Spark Connect client
5211f6b140a is described below

commit 5211f6b140a74bd28f7e05934508bdafdbe7f237
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Nov 24 17:52:23 2023 -0800

    [SPARK-46085][CONNECT] Dataset.groupingSets in Scala Spark Connect client
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add the `Dataset.groupingSets` API, added in https://github.com/apache/spark/pull/43813, to the Scala Spark Connect client.
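
    For illustration, a minimal usage sketch of the new API (the `ds` DataFrame and its `department`, `group`, and `salary` columns are hypothetical; the exact signature is in the diff below):

        import org.apache.spark.sql.Column
        import org.apache.spark.sql.functions.col

        // Aggregate over two grouping sets: (department, group) and the
        // grand total (), then average salary within each set.
        ds.groupingSets(
            Seq(Seq(col("department"), col("group")), Seq.empty[Column]),
            col("department"),
            col("group"))
          .agg("salary" -> "avg")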
    
    ### Why are the changes needed?
    
    For feature parity.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it adds a new API to the Scala Spark Connect client.
    
    ### How was this patch tested?
    
    A unit test was added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43995 from HyukjinKwon/SPARK-46085.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  35 +++++++++++++++
 .../spark/sql/RelationalGroupedDataset.scala       |   8 +++-
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   6 +++
 .../explain-results/groupingSets.explain           |   4 ++
 .../query-tests/queries/groupingSets.json          |  50 +++++++++++++++++++++
 .../query-tests/queries/groupingSets.proto.bin     | Bin 0 -> 106 bytes
 6 files changed, 102 insertions(+), 1 deletion(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index a1e57226e53..d760c9d9769 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1532,6 +1532,41 @@ class Dataset[T] private[sql] (
       proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
   }
 
+  /**
+   * Create multi-dimensional aggregation for the current Dataset using the specified grouping
+   * sets, so we can run aggregation on them. See [[RelationalGroupedDataset]] for all the
+   * available aggregate functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns grouped by specific grouping sets.
+   *   ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, grouped by specific grouping sets.
+   *   ds.groupingSets(Seq(Seq($"department", $"gender"), Seq()), $"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 4.0.0
+   */
+  @scala.annotation.varargs
+  def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = {
+    val groupingSetMsgs = groupingSets.map { groupingSet =>
+      val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder()
+      for (groupCol <- groupingSet) {
+        groupingSetMsg.addGroupingSet(groupCol.expr)
+      }
+      groupingSetMsg.build()
+    }
+    new RelationalGroupedDataset(
+      toDF(),
+      cols,
+      proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS,
+      groupingSets = Some(groupingSetMsgs))
+  }
+
   /**
    * (Scala-specific) Aggregates on the entire Dataset without groups.
    * {{{
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 5ed97e45c77..776a6231eae 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -39,7 +39,8 @@ class RelationalGroupedDataset private[sql] (
     private[sql] val df: DataFrame,
     private[sql] val groupingExprs: Seq[Column],
     groupType: proto.Aggregate.GroupType,
-    pivot: Option[proto.Aggregate.Pivot] = None) {
+    pivot: Option[proto.Aggregate.Pivot] = None,
+    groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) {
 
   private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
     df.sparkSession.newDataFrame { builder =>
@@ -60,6 +61,11 @@ class RelationalGroupedDataset private[sql] (
           builder.getAggregateBuilder
             .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT)
             .setPivot(pivot.get)
+        case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
+          assert(groupingSets.isDefined)
+          val aggBuilder = builder.getAggregateBuilder
+            .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS)
+          groupingSets.get.foreach(aggBuilder.addGroupingSets)
         case g => throw new UnsupportedOperationException(g.toString)
       }
     }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 5cc63bc45a0..c5c917ebfa9 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -3017,6 +3017,12 @@ class PlanGenerationTestSuite
     simple.groupBy(Column("id")).pivot("a").agg(functions.count(Column("b")))
   }
 
+  test("groupingSets") {
+    simple
+      .groupingSets(Seq(Seq(fn.col("a")), Seq.empty[Column]), fn.col("a"))
+      .agg("a" -> "max", "a" -> "count")
+  }
+
   test("width_bucket") {
     simple.select(fn.width_bucket(fn.col("b"), fn.col("b"), fn.col("b"), fn.col("a")))
   }
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/groupingSets.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/groupingSets.explain
new file mode 100644
index 00000000000..1e3fe1a987e
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/groupingSets.explain
@@ -0,0 +1,4 @@
+Aggregate [a#0, spark_grouping_id#0L], [a#0, max(a#0) AS max(a)#0, count(a#0) AS count(a)#0L]
++- Expand [[id#0L, a#0, b#0, a#0, 0], [id#0L, a#0, b#0, null, 1]], [id#0L, a#0, b#0, a#0, spark_grouping_id#0L]
+   +- Project [id#0L, a#0, b#0, a#0 AS a#0]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.json b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.json
new file mode 100644
index 00000000000..6e84824ec7a
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.json
@@ -0,0 +1,50 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "aggregate": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_GROUPING_SETS",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "a"
+      }
+    }],
+    "aggregateExpressions": [{
+      "unresolvedFunction": {
+        "functionName": "max",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a",
+            "planId": "0"
+          }
+        }]
+      }
+    }, {
+      "unresolvedFunction": {
+        "functionName": "count",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a",
+            "planId": "0"
+          }
+        }]
+      }
+    }],
+    "groupingSets": [{
+      "groupingSet": [{
+        "unresolvedAttribute": {
+          "unparsedIdentifier": "a"
+        }
+      }]
+    }, {
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin
new file mode 100644
index 00000000000..ce029409670
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin differ
 differ


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
