hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1177017870


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -0,0 +1,416 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import org.apache.spark.api.java.function._
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
+import org.apache.spark.sql.functions.col
+
+/**
+ * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
+ * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on an
+ * existing [[Dataset]].
+ *
+ * @since 3.5.0
+ */
+abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
+   * specified type. The mapping of key columns to the type follows the same rules as `as` on
+   * [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+   * the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create values grouped by key from a Dataset[(K, V)]
+   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+   * the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+   *   Dataset<Tuple2<String, Integer>> ds = ...;
+   *   KeyValueGroupedDataset<String, Integer> grouped =
+   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
+   * }}}
+   *
+   * @since 3.5.0
+   */
+  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
+    mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
+  }
+
+  /**
+   * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping over
+   * the Dataset to extract the keys and then running a distinct operation on those.
+   *
+   * @since 3.5.0
+   */
+  def keys: Dataset[K] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    flatMapSortedGroups()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an iterator containing elements of an arbitrary type which
+   * will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
+      f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and a sorted iterator that contains all of the elements
+   * in the group. The function can return an iterator containing elements of an arbitrary type
+   * which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
+   * @since 3.5.0
+   */
+  def flatMapSortedGroups[U](
+      SortExprs: Array[Column],
+      f: FlatMapGroupsFunction[K, V, U],
+      encoder: Encoder[U]): Dataset[U] = {
+    flatMapSortedGroups(SortExprs: _*)(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+    flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f))
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data. For each unique group, the
+   * function will be passed the group key and an iterator that contains all of the elements in
+   * the group. The function can return an element of arbitrary type which will be returned as a
+   * new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an
+   * `org.apache.spark.sql.expressions#Aggregator`.
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the
+   * memory constraints of their cluster.
+   *
+   * @since 3.5.0
+   */
+  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each cogrouped data. For each unique group,
+   * the function will be passed the grouping key and 2 iterators containing all elements in the
+   * group from [[Dataset]] `this` and `other`. The function can return an iterator containing
+   * elements of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    cogroupSorted(other)()()(f)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the
+   * function will be passed the grouping key and 2 iterators containing all elements in the group
+   * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
+   * of an arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 3.5.0
+   */
+  def cogroup[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)(
+      otherSortExprs: Column*)(
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique
+   * group, the function will be passed the grouping key and 2 sorted iterators containing all
+   * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+   * iterator containing elements of an arbitrary type which will be returned as a new
+   * [[Dataset]].
+   *
+   * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
+   * sorted according to the given sort expressions. That sorting does not add computational
+   * complexity.
+   *
+   * @see
+   *   [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
+   * @since 3.5.0
+   */
+  def cogroupSorted[U, R](
+      other: KeyValueGroupedDataset[K, U],
+      thisSortExprs: Array[Column],
+      otherSortExprs: Array[Column],
+      f: CoGroupFunction[K, V, U, R],
+      encoder: Encoder[R]): Dataset[R] = {
+    cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)(
+      UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
+  }
+}
+
+/**
+ * This class is the implementation of class [[KeyValueGroupedDataset]]. This class memorizes the
+ * initial types of the grouping function so that the original function will be sent to the server
+ * to perform the grouping first. Then any type modifications on the keys and the values will be
+ * applied sequentially to ensure the final type of the result remains the same as how
+ * [[KeyValueGroupedDataset]] behaves on the server.
+ */
+private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
+    private val ds: Dataset[IV],
+    private val sparkSession: SparkSession,
+    private val plan: proto.Plan,
+    private val ikEncoder: AgnosticEncoder[IK],
+    private val kEncoder: AgnosticEncoder[K],
+    private val groupingFunc: IV => IK,
+    private val valueMapFunc: IV => V)
+    extends KeyValueGroupedDataset[K, V] {
+
+  private val ivEncoder = ds.encoder
+
+  override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {

Review Comment:
   We had a discussion about this one. `keyAs` just changes the key type. It does not change the
   cardinality of the result. The cardinality is determined by the result of the grouping function.
   
   For example, the following returns 4 groups instead of 2:
   ```scala
   case class K1(a: Long)
   case class K2(a: Long, b: Long)
   spark.range(10).as[Long].groupByKey(id => K2(id % 2, id % 4)).keyAs[K1].count().collect()
   ```
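   
   As a rough sketch of why 4 is the number to expect (same toy classes as above, assuming the
   usual `spark.implicits._` are in scope), the groups correspond to the distinct keys produced by
   the grouping function, not to the `K1` projection:
   ```scala
   // Distinct keys produced by the grouping function over range(10):
   // K2(0, 0), K2(1, 1), K2(0, 2), K2(1, 3)  => 4 groups
   // keyAs[K1] only re-interprets the key columns as K1, so K1(0) and K1(1)
   // each show up twice in the result instead of the data being regrouped to 2 keys.
   spark.range(10).as[Long].map(id => K2(id % 2, id % 4)).distinct().count() // 4
   ```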
   
   @cloud-fan is this expected?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org
