This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 21767d29b36 [SPARK-42529][CONNECT] Support Cube and Rollup in Scala client
21767d29b36 is described below

commit 21767d29b36c3c8d812bb3ea8946a21a8ef6e65c
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Wed Feb 22 23:56:38 2023 -0400

    [SPARK-42529][CONNECT] Support Cube and Rollup in Scala client

    ### What changes were proposed in this pull request?

    Support Cube and Rollup in Scala client.

    ### Why are the changes needed?

    API Coverage

    ### Does this PR introduce _any_ user-facing change?

    NO

    ### How was this patch tested?

    UT

    Closes #40129 from amaliujia/support_cube_rollup_pivot.

    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 120 ++++++++++++++++++++-
 .../spark/sql/RelationalGroupedDataset.scala       |  16 ++-
 .../apache/spark/sql/PlanGenerationTestSuite.scala |  16 +++
 .../explain-results/cube_column.explain            |   4 +
 .../explain-results/cube_string.explain            |   4 +
 .../explain-results/rollup_column.explain          |   4 +
 .../explain-results/rollup_string.explain          |   4 +
 .../resources/query-tests/queries/cube_column.json |  34 ++++++
 .../query-tests/queries/cube_column.proto.bin      |   7 ++
 .../resources/query-tests/queries/cube_string.json |  34 ++++++
 .../query-tests/queries/cube_string.proto.bin      |   7 ++
 .../query-tests/queries/rollup_column.json         |  34 ++++++
 .../query-tests/queries/rollup_column.proto.bin    |   7 ++
 .../query-tests/queries/rollup_string.json         |  34 ++++++
 .../query-tests/queries/rollup_string.proto.bin    |   7 ++
 15 files changed, 328 insertions(+), 4 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index c7ded04a963..560276d154e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1055,7 +1055,125 @@ class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan:
    */
   @scala.annotation.varargs
   def groupBy(cols: Column*): RelationalGroupedDataset = {
-    new RelationalGroupedDataset(toDF(), cols.map(_.expr))
+    new RelationalGroupedDataset(
+      toDF(),
+      cols.map(_.expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+  }
+
+  /**
+   * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we
+   * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+   * functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns rolled up by department and group.
+   *   ds.rollup($"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, rolled up by department and gender.
+   *   ds.rollup($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def rollup(cols: Column*): RelationalGroupedDataset = {
+    new RelationalGroupedDataset(
+      toDF(),
+      cols.map(_.expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+  }
+
+  /**
+   * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we
+   * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+   * functions.
+   *
+   * This is a variant of rollup that can only group by existing columns using column names (i.e.
+   * cannot construct expressions).
+   *
+   * {{{
+   *   // Compute the average for all numeric columns rolled up by department and group.
+   *   ds.rollup("department", "group").avg()
+   *
+   *   // Compute the max age and average salary, rolled up by department and gender.
+   *   ds.rollup($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def rollup(col1: String, cols: String*): RelationalGroupedDataset = {
+    val colNames: Seq[String] = col1 +: cols
+    new RelationalGroupedDataset(
+      toDF(),
+      colNames.map(colName => Column(colName).expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+  }
+
+  /**
+   * Create a multi-dimensional cube for the current Dataset using the specified columns, so we
+   * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+   * functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns cubed by department and group.
+   *   ds.cube($"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, cubed by department and gender.
+   *   ds.cube($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def cube(cols: Column*): RelationalGroupedDataset = {
+    new RelationalGroupedDataset(
+      toDF(),
+      cols.map(_.expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+  }
+
+  /**
+   * Create a multi-dimensional cube for the current Dataset using the specified columns, so we
+   * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+   * functions.
+   *
+   * This is a variant of cube that can only group by existing columns using column names (i.e.
+   * cannot construct expressions).
+   *
+   * {{{
+   *   // Compute the average for all numeric columns cubed by department and group.
+   *   ds.cube("department", "group").avg()
+   *
+   *   // Compute the max age and average salary, cubed by department and gender.
+   *   ds.cube($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def cube(col1: String, cols: String*): RelationalGroupedDataset = {
+    val colNames: Seq[String] = col1 +: cols
+    new RelationalGroupedDataset(
+      toDF(),
+      colNames.map(colName => Column(colName).expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
   }
 
   /**
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index a6d3dc2e468..76db231db9e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -37,16 +37,26 @@ import org.apache.spark.connect.proto
  */
 class RelationalGroupedDataset protected[sql] (
     private[sql] val df: DataFrame,
-    private[sql] val groupingExprs: Seq[proto.Expression]) {
+    private[sql] val groupingExprs: Seq[proto.Expression],
+    groupType: proto.Aggregate.GroupType) {
 
   private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
-    // TODO: support other GroupByType such as Rollup, Cube, Pivot.
     df.session.newDataset { builder =>
       builder.getAggregateBuilder
-        .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
         .setInput(df.plan.getRoot)
         .addAllGroupingExpressions(groupingExprs.asJava)
         .addAllAggregateExpressions(aggExprs.map(e => e.expr).asJava)
+
+      // TODO: support Pivot.
+      groupType match {
+        case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
+          builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+        case proto.Aggregate.GroupType.GROUP_TYPE_CUBE =>
+          builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+        case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
+          builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+        case g => throw new UnsupportedOperationException(g.toString)
+      }
     }
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 42572f8427e..9ca91942567 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1663,6 +1663,22 @@ class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll wit
       .count()
   }
 
+  test("rollup column") {
+    simple.rollup(Column("a"), Column("b")).count()
+  }
+
+  test("cube column") {
+    simple.cube(Column("a"), Column("b")).count()
+  }
+
+  test("rollup string") {
+    simple.rollup("a", "b").count()
+  }
+
+  test("cube string") {
+    simple.cube("a", "b").count()
+  }
+
   test("function lit") {
     simple.select(
       fn.lit(fn.col("id")),
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/cube_column.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/cube_column.explain
new file mode 100644
index 00000000000..1721162f478
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/cube_column.explain
@@ -0,0 +1,4 @@
+Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, count(1) AS count#0L]
++- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, b#0, 2], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L]
+   +- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/cube_string.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/cube_string.explain
new file mode 100644
index 00000000000..1721162f478
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/cube_string.explain
@@ -0,0 +1,4 @@
+Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, count(1) AS count#0L]
++- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, b#0, 2], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L]
+   +- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_column.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_column.explain
new file mode 100644
index 00000000000..c8f0f1e2aeb
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_column.explain
@@ -0,0 +1,4 @@
+Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, count(1) AS count#0L]
++- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L]
+   +- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_string.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_string.explain
new file mode 100644
index 00000000000..c8f0f1e2aeb
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_string.explain
@@ -0,0 +1,4 @@
+Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, count(1) AS count#0L]
++- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L]
+   +- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/cube_column.json b/connector/connect/common/src/test/resources/query-tests/queries/cube_column.json
new file mode 100644
index 00000000000..49016593a34
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/cube_column.json
@@ -0,0 +1,34 @@
+{
+  "aggregate": {
+    "input": {
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_CUBE",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "a"
+      }
+    }, {
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "b"
+      }
+    }],
+    "aggregateExpressions": [{
+      "alias": {
+        "expr": {
+          "unresolvedFunction": {
+            "functionName": "count",
+            "arguments": [{
+              "literal": {
+                "integer": 1
+              }
+            }]
+          }
+        },
+        "name": ["count"]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/cube_column.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/cube_column.proto.bin
new file mode 100644
index 00000000000..c706144de59
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/cube_column.proto.bin
@@ -0,0 +1,7 @@
+JR
+$Z" struct<id:bigint,a:int,b:double>
+a
+b"2
+
+count
+0count
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/cube_string.json b/connector/connect/common/src/test/resources/query-tests/queries/cube_string.json
new file mode 100644
index 00000000000..49016593a34
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/cube_string.json
@@ -0,0 +1,34 @@
+{
+  "aggregate": {
+    "input": {
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_CUBE",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "a"
+      }
+    }, {
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "b"
+      }
+    }],
+    "aggregateExpressions": [{
+      "alias": {
+        "expr": {
+          "unresolvedFunction": {
+            "functionName": "count",
+            "arguments": [{
+              "literal": {
+                "integer": 1
+              }
+            }]
+          }
+        },
+        "name": ["count"]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin
new file mode 100644
index 00000000000..c706144de59
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin
@@ -0,0 +1,7 @@
+JR
+$Z" struct<id:bigint,a:int,b:double>
+a
+b"2
+
+count
+0count
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.json b/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.json
new file mode 100644
index 00000000000..f976e4ea10f
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.json
@@ -0,0 +1,34 @@
+{
+  "aggregate": {
+    "input": {
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_ROLLUP",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "a"
+      }
+    }, {
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "b"
+      }
+    }],
+    "aggregateExpressions": [{
+      "alias": {
+        "expr": {
+          "unresolvedFunction": {
+            "functionName": "count",
+            "arguments": [{
+              "literal": {
+                "integer": 1
+              }
+            }]
+          }
+        },
+        "name": ["count"]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.proto.bin
new file mode 100644
index 00000000000..89ef8ff947b
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.proto.bin
@@ -0,0 +1,7 @@
+JR
+$Z" struct<id:bigint,a:int,b:double>
+a
+b"2
+
+count
+0count
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.json b/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.json
new file mode 100644
index 00000000000..f976e4ea10f
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.json
@@ -0,0 +1,34 @@
+{
+  "aggregate": {
+    "input": {
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_ROLLUP",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "a"
+      }
+    }, {
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "b"
+      }
+    }],
+    "aggregateExpressions": [{
+      "alias": {
+        "expr": {
+          "unresolvedFunction": {
+            "functionName": "count",
+            "arguments": [{
+              "literal": {
+                "integer": 1
+              }
+            }]
+          }
+        },
+        "name": ["count"]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin
new file mode 100644
index 00000000000..89ef8ff947b
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin
@@ -0,0 +1,7 @@
+JR
+$Z" struct<id:bigint,a:int,b:double>
+a
+b"2
+
+count
+0count
\ No newline at end of file

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
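A minimal usage sketch of the newly supported API follows; it simply restates the
calls exercised by the new PlanGenerationTestSuite tests, and the DataFrame `df`
and the helper name `groupingSetExamples` are illustrative stand-ins, not
identifiers from this commit:

    import org.apache.spark.sql.{functions => fn, DataFrame}

    // `df` stands in for any Connect-client DataFrame, e.g. one with the test
    // schema struct<id:bigint,a:int,b:double> used by PlanGenerationTestSuite.
    def groupingSetExamples(df: DataFrame): Unit = {
      // Column-based variants (same shape as the "rollup column" / "cube column" tests).
      df.rollup(fn.col("a"), fn.col("b")).count()
      df.cube(fn.col("a"), fn.col("b")).count()

      // Name-based variants (same shape as the "rollup string" / "cube string" tests).
      df.rollup("a", "b").count()
      df.cube("a", "b").count()
    }

Each of these builds an Aggregate relation whose groupType is set to
GROUP_TYPE_ROLLUP or GROUP_TYPE_CUBE by the updated RelationalGroupedDataset,
as reflected in the cube_*/rollup_* golden files above.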