This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new cc471a52d16 [SPARK-42468][CONNECT] Implement agg by (String, String)* cc471a52d16 is described below commit cc471a52d162d0e4d4063372253ed06a62f5cb19 Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Thu Feb 16 23:02:51 2023 -0400 [SPARK-42468][CONNECT] Implement agg by (String, String)* ### What changes were proposed in this pull request? Starting to support basic aggregation in Scala client. The first step is to support aggregation by strings. ### Why are the changes needed? API coverage ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #40057 from amaliujia/rw-agg. Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../main/scala/org/apache/spark/sql/Dataset.scala | 23 ++++ .../spark/sql/RelationalGroupedDataset.scala | 152 +++++++++++++++++++++ .../apache/spark/sql/PlanGenerationTestSuite.scala | 14 ++ .../explain-results/groupby_agg.explain | 2 + .../resources/query-tests/queries/groupby_agg.json | 88 ++++++++++++ .../query-tests/queries/groupby_agg.proto.bin | 19 +++ 6 files changed, 298 insertions(+) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 977c823f7c7..c39fc6100f5 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1035,6 +1035,29 @@ class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan: } } + /** + * Groups the Dataset using the specified columns, so we can run aggregation on them. See + * [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns grouped by department. 
+ * ds.groupBy($"department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * ds.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 3.4.0 + */ + @scala.annotation.varargs + def groupBy(cols: Column*): RelationalGroupedDataset = { + new RelationalGroupedDataset(toDF(), cols.map(_.expr)) + } + /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala new file mode 100644 index 00000000000..a3dfcb01fdc --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.util.Locale + +import scala.collection.JavaConverters._ + +import org.apache.spark.connect.proto + +/** + * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], + * [[Dataset#cube cube]] or [[Dataset#rollup rollup]] (and also `pivot`). + * + * The main method is the `agg` function, which has multiple variants. This class also contains + * some first-order statistics such as `mean`, `sum` for convenience. + * + * @note + * This class was named `GroupedData` in Spark 1.x. + * + * @since 3.4.0 + */ +class RelationalGroupedDataset protected[sql] ( + private[sql] val df: DataFrame, + private[sql] val groupingExprs: Seq[proto.Expression]) { + + private[this] def toDF(aggExprs: Seq[proto.Expression]): DataFrame = { + // TODO: support other GroupByType such as Rollup, Cube, Pivot. + df.session.newDataset { builder => + builder.getAggregateBuilder + .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) + .setInput(df.plan.getRoot) + .addAllGroupingExpressions(groupingExprs.asJava) + .addAllAggregateExpressions(aggExprs.asJava) + } + } + + /** + * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The + * resulting `DataFrame` will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg( + * "age" -> "max", + * "expense" -> "sum" + * ) + * }}} + * + * @since 3.4.0 + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { + toDF((aggExpr +: aggExprs).map { case (colName, expr) => + strToExpr(expr, df(colName).expr) + }) + } + + /** + * (Scala-specific) Compute aggregates by specifying a map from column name to aggregate + * methods. The resulting `DataFrame` will also contain the grouping columns. 
+ * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max", + * "expense" -> "sum" + * )) + * }}} + * + * @since 3.4.0 + */ + def agg(exprs: Map[String, String]): DataFrame = { + toDF(exprs.map { case (colName, expr) => + strToExpr(expr, df(colName).expr) + }.toSeq) + } + + /** + * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods. + * The resulting `DataFrame` will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * import com.google.common.collect.ImmutableMap; + * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); + * }}} + * + * @since 3.4.0 + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = { + agg(exprs.asScala.toMap) + } + + private[this] def strToExpr(expr: String, inputExpr: proto.Expression): proto.Expression = { + val builder = proto.Expression.newBuilder() + + expr.toLowerCase(Locale.ROOT) match { + // We specially handle a few cases that have aliases that are not in the function registry. + case "avg" | "average" | "mean" => + builder.getUnresolvedFunctionBuilder + .setFunctionName("avg") + .addArguments(inputExpr) + .setIsDistinct(false) + case "stddev" | "std" => + builder.getUnresolvedFunctionBuilder + .setFunctionName("stddev") + .addArguments(inputExpr) + .setIsDistinct(false) + // Also specially handle count because we need to take care of count(*). 
+ case "count" | "size" => + // Turn count(*) into count(1) + inputExpr match { + case s if s.hasUnresolvedStar => + val exprBuilder = proto.Expression.newBuilder + exprBuilder.getLiteralBuilder.setInteger(1) + builder.getUnresolvedFunctionBuilder + .setFunctionName("count") + .addArguments(exprBuilder) + .setIsDistinct(false) + case _ => + builder.getUnresolvedFunctionBuilder + .setFunctionName("count") + .addArguments(inputExpr) + .setIsDistinct(false) + } + case name => + builder.getUnresolvedFunctionBuilder + .setFunctionName(name) + .addArguments(inputExpr) + .setIsDistinct(false) + } + builder.build() + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 7b5d8bd1018..8d4550dfe4f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -784,6 +784,20 @@ class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll wit select(fn.max(Column("id"))) } + test("groupby agg") { + simple + .groupBy(Column("id")) + .agg( + "a" -> "max", + "b" -> "stddev", + "b" -> "std", + "b" -> "mean", + "b" -> "average", + "b" -> "avg", + "*" -> "size", + "a" -> "count") + } + test("function lit") { select( fn.lit(fn.col("id")), diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg.explain new file mode 100644 index 00000000000..acb42c1408c --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/groupby_agg.explain @@ -0,0 +1,2 @@ +Aggregate [id#0L], [id#0L, max(a#0) AS max(a)#0, stddev(b#0) AS stddev(b)#0, stddev(b#0) AS stddev(b)#0, avg(b#0) AS avg(b)#0, avg(b#0) AS avg(b)#0, 
avg(b#0) AS avg(b)#0, count(1) AS count(1)#0L, count(a#0) AS count(a)#0L] ++- LocalRelation <empty>, [id#0L, a#0, b#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json new file mode 100644 index 00000000000..7838a89974d --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json @@ -0,0 +1,88 @@ +{ + "aggregate": { + "input": { + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double\u003e" + } + }, + "groupType": "GROUP_TYPE_GROUPBY", + "groupingExpressions": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "id" + } + }], + "aggregateExpressions": [{ + "unresolvedFunction": { + "functionName": "max", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }] + } + }, { + "unresolvedFunction": { + "functionName": "stddev", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }, { + "unresolvedFunction": { + "functionName": "stddev", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }, { + "unresolvedFunction": { + "functionName": "avg", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }, { + "unresolvedFunction": { + "functionName": "avg", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }, { + "unresolvedFunction": { + "functionName": "avg", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }, { + "unresolvedFunction": { + "functionName": "count", + "arguments": [{ + "literal": { + "integer": 1 + } + }] + } + }, { + "unresolvedFunction": { + "functionName": "count", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }] + } + }] + } +} \ No newline at end of file diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin new file mode 100644 index 00000000000..9c6d1cca8a4 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin @@ -0,0 +1,19 @@ +J� +$Z" struct<id:bigint,a:int,b:double> +id" +max +a" +stddev +b" +stddev +b" +avg +b" +avg +b" +avg +b" +count +0" +count +a \ No newline at end of file --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org