Repository: spark Updated Branches: refs/heads/master fb0894b3a -> 84b245f2d
[SPARK-15780][SQL] Support mapValues on KeyValueGroupedDataset ## What changes were proposed in this pull request? Add mapValues to KeyValueGroupedDataset ## How was this patch tested? New test in DatasetSuite for groupBy function, mapValues, flatMap Author: Koert Kuipers <ko...@tresata.com> Closes #13526 from koertkuipers/feat-keyvaluegroupeddataset-mapvalues. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84b245f2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84b245f2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84b245f2 Branch: refs/heads/master Commit: 84b245f2dd31c1cebbf12458bf11f67e287e93f4 Parents: fb0894b Author: Koert Kuipers <ko...@tresata.com> Authored: Thu Oct 20 10:08:12 2016 -0700 Committer: Reynold Xin <r...@databricks.com> Committed: Thu Oct 20 10:08:12 2016 -0700 ---------------------------------------------------------------------- .../sql/catalyst/plans/logical/object.scala | 13 ++++++ .../spark/sql/KeyValueGroupedDataset.scala | 42 ++++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 11 +++++ 3 files changed, 66 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/84b245f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index fefe5a3..0ab4c90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -230,6 +230,19 @@ object AppendColumns { encoderFor[U].namedExpressions, child) } + + def apply[T : Encoder, U : Encoder]( + func: T => U, + 
inputAttributes: Seq[Attribute], + child: LogicalPlan): AppendColumns = { + new AppendColumns( + func.asInstanceOf[Any => Any], + implicitly[Encoder[T]].clsTag.runtimeClass, + implicitly[Encoder[T]].schema, + UnresolvedDeserializer(encoderFor[T].deserializer, inputAttributes), + encoderFor[U].namedExpressions, + child) + } } /** http://git-wip-us.apache.org/repos/asf/spark/blob/84b245f2/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 828eb94..4cb0313 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -67,6 +67,48 @@ class KeyValueGroupedDataset[K, V] private[sql]( groupingAttributes) /** + * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied + * to the data. The grouping key is unchanged by this. + * + * {{{ + * // Create values grouped by key from a Dataset[(K, V)] + * ds.groupByKey(_._1).mapValues(_._2) // Scala + * }}} + * + * @since 2.1.0 + */ + def mapValues[W : Encoder](func: V => W): KeyValueGroupedDataset[K, W] = { + val withNewData = AppendColumns(func, dataAttributes, logicalPlan) + val projected = Project(withNewData.newColumns ++ groupingAttributes, withNewData) + val executed = sparkSession.sessionState.executePlan(projected) + + new KeyValueGroupedDataset( + encoderFor[K], + encoderFor[W], + executed, + withNewData.newColumns, + groupingAttributes) + } + + /** + * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied + * to the data. The grouping key is unchanged by this. 
+ * + * {{{ + * // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>> + * Dataset<Tuple2<String, Integer>> ds = ...; + * KeyValueGroupedDataset<String, Integer> grouped = + * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); // Java 8 + * }}} + * + * @since 2.1.0 + */ + def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = { + implicit val uEnc: Encoder[W] = encoder + mapValues { (v: V) => func.call(v) } + } + + /** + * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping * over the Dataset to extract the keys and then running a distinct operation on those. * http://git-wip-us.apache.org/repos/asf/spark/blob/84b245f2/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5fce9b4..cc367ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -336,6 +336,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { "a", "30", "b", "3", "c", "1") } + test("groupBy function, mapValues, flatMap") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val keyValue = ds.groupByKey(_._1).mapValues(_._2) + val agged = keyValue.mapGroups { case (g, iter) => (g, iter.sum) } + checkDataset(agged, ("a", 30), ("b", 3), ("c", 1)) + + val keyValue1 = ds.groupByKey(t => (t._1, "key")).mapValues(t => (t._2, "value")) + val agged1 = keyValue1.mapGroups { case (g, iter) => (g._1, iter.map(_._1).sum) } + checkDataset(agged1, ("a", 30), ("b", 3), ("c", 1)) + } + test("groupBy function, reduce") { val ds = Seq("abc", "xyz", "hello").toDS() val agged = ds.groupByKey(_.length).reduceGroups(_ + _) 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org