GitHub user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22944#discussion_r231001655
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala ---
    @@ -262,25 +262,39 @@ object AppendColumns {
       def apply[T : Encoder, U : Encoder](
           func: T => U,
           child: LogicalPlan): AppendColumns = {
    +    val outputEncoder = encoderFor[U]
    +    val namedExpressions = if (!outputEncoder.isSerializedAsStruct) {
    +      assert(outputEncoder.namedExpressions.length == 1)
    +      outputEncoder.namedExpressions.map(Alias(_, "key")())
    +    } else {
    +      outputEncoder.namedExpressions
    +    }
         new AppendColumns(
           func.asInstanceOf[Any => Any],
           implicitly[Encoder[T]].clsTag.runtimeClass,
           implicitly[Encoder[T]].schema,
           UnresolvedDeserializer(encoderFor[T].deserializer),
    -      encoderFor[U].namedExpressions,
    +      namedExpressions,
           child)
       }
     
       def apply[T : Encoder, U : Encoder](
           func: T => U,
           inputAttributes: Seq[Attribute],
           child: LogicalPlan): AppendColumns = {
    +    val outputEncoder = encoderFor[U]
    +    val namedExpressions = if (!outputEncoder.isSerializedAsStruct) {
    +      assert(outputEncoder.namedExpressions.length == 1)
    +      outputEncoder.namedExpressions.map(Alias(_, "key")())
    +    } else {
    +      outputEncoder.namedExpressions
    --- End diff --
    
    For primitive types and product types, it looks like this works:
    ```scala
    test("typed aggregation on primitive data") {
      val ds = Seq(1, 2, 3).toDS()
      val agg = ds.select(expr("value").as("data").as[Int])
        .groupByKey(_ >= 2)
        .agg(sum("data").as[Long], sum($"data" + 1).as[Long])
      agg.show()
    }
    ```
    ```
    +-----+---------+---------------+
    |value|sum(data)|sum((data + 1))|
    +-----+---------+---------------+
    |false|        1|              2|
    | true|        5|              7|
    +-----+---------+---------------+
    ```
    
    ```scala
    test("typed aggregation on product data") {
      val ds = Seq((1, 2), (2, 3), (3, 4)).toDS()
      val agg = ds.select(expr("_1").as("a").as[Int], expr("_2").as("b").as[Int])
        .groupByKey(_._1).agg(sum("a").as[Int], sum($"b" + 1).as[Int])
      agg.show
    }
    ```
    ```
    [info] - typed aggregation on primitive data (192 milliseconds)
    +-----+------+------------+
    |value|sum(a)|sum((b + 1))|
    +-----+------+------------+
    |    3|     3|           5|
    |    1|     1|           3|
    |    2|     2|           4|
    +-----+------+------------+
    
    ```


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to