This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 1134fae [SPARK-18299][SQL] Allow more aggregations on KeyValueGroupedDataset 1134fae is described below commit 1134faecf4fe2cc1bf7c3670f5f8f2b9d0c6f2e7 Author: nooberfsh <noober...@gmail.com> AuthorDate: Tue Jul 16 16:35:04 2019 -0700 [SPARK-18299][SQL] Allow more aggregations on KeyValueGroupedDataset ## What changes were proposed in this pull request? Add 4 additional agg to KeyValueGroupedDataset ## How was this patch tested? New test in DatasetSuite for typed aggregation Closes #24993 from nooberfsh/sqlagg. Authored-by: nooberfsh <noober...@gmail.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../apache/spark/sql/KeyValueGroupedDataset.scala | 65 ++++++++++++++++++++++ .../scala/org/apache/spark/sql/DatasetSuite.scala | 64 +++++++++++++++++++++ 2 files changed, 129 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index a3cbea9..0da52d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -521,6 +521,71 @@ class KeyValueGroupedDataset[K, V] private[sql]( aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = + aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5, U6]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = + aggUntyped(col1, col2, col3, col4, col5, col6) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5, U6, U7]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = + aggUntyped(col1, col2, col3, col4, col5, col6, col7) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + * + * @since 3.0.0 + */ + def agg[U1, U2, U3, U4, U5, U6, U7, U8]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7], + col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = + aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]] + + /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 4b08a4b..ff61431 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -603,6 +603,70 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 1.0)) } + test("typed aggregation: expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1).agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long]), + ("a", 30L, 32L, 2L, 15.0, 2L), ("b", 3L, 5L, 2L, 1.5, 2L), ("c", 1L, 2L, 1L, 1.0, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1).agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long], + max("_2").as[Long]), + ("a", 30L, 32L, 2L, 15.0, 2L, 20L), + ("b", 3L, 5L, 2L, 1.5, 2L, 2L), + ("c", 1L, 2L, 1L, 1.0, 1L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1).agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long], + max("_2").as[Long], + min("_2").as[Long]), + ("a", 30L, 32L, 2L, 15.0, 2L, 20L, 10L), + ("b", 3L, 5L, 2L, 1.5, 2L, 2L, 1L), + ("c", 1L, 2L, 1L, 1.0, 1L, 1L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1).agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long], + max("_2").as[Long], + min("_2").as[Long], + mean("_2").as[Double]), + ("a", 30L, 32L, 2L, 15.0, 2L, 20L, 10L, 15.0), + ("b", 3L, 5L, 2L, 1.5, 2L, 2L, 1L, 1.5), + ("c", 1L, 2L, 1L, 1.0, 1L, 1L, 1L, 1.0)) + } + test("cogroup") { val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS() val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org