Refined further: I got rid of wrapping the fields into a struct, but the result
type of the UDAF is still a struct. I need to extract the fields one by one; I
guess I just haven't found a function that does that.

I wrote this code without an IDE and ran it from spark-shell, so there are
probably many spots where it could be shortened or cleaned up.

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import java.sql.Date

class MaxRow extends UserDefinedAggregateFunction {
  // These are the input fields for your aggregate function.
  override def inputSchema: org.apache.spark.sql.types.StructType =
    new StructType(Array(
        StructField("AMOUNT", IntegerType, true),
        StructField("MY_TIMESTAMP", DateType, true)))

  // These are the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType =
    new StructType(Array(
        StructField("AMOUNT", IntegerType, true),
        StructField("MY_TIMESTAMP", DateType, true)))

  // This is the output type of your aggregation function.
  override def dataType: DataType =
  new StructType().add("st", StructType(Seq(
        StructField("AMOUNT", IntegerType, true),
        StructField("MY_TIMESTAMP", DateType, true))))

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema: an empty (all-null) buffer.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = null
    buffer(1) = null
  }

  // This is how to update your buffer schema given an input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Skip rows with a null AMOUNT; otherwise keep the row with the larger AMOUNT.
    if (!input.isNullAt(0) && (buffer.isNullAt(0) || buffer.getInt(0) < input.getInt(0))) {
      buffer(0) = input(0)
      buffer(1) = input(1)
    }
  }

  // This is how to merge two objects with the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer1.isNullAt(0) || (!buffer2.isNullAt(0) &&
        buffer1.getInt(0) < buffer2.getInt(0))) {
      buffer1(0) = buffer2(0)
      buffer1(1) = buffer2(1)
    }
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    Row(Row(buffer(0), buffer(1)))
  }
}

val maxrow = new MaxRow
spark.udf.register("maxrow", maxrow)

import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import org.apache.spark.sql.types._
import spark.implicits._

val socketDF = spark
  .readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()

val schema = StructType(Seq(
  StructField("ID", IntegerType, true),
  StructField("AMOUNT", IntegerType, true),
  StructField("MY_TIMESTAMP", DateType, true)
))

val query = socketDF
  .selectExpr("CAST(value AS STRING) as value")
  .select(from_json($"value", schema=schema).as("data"))
  .select($"data.ID", $"data.AMOUNT", $"data.MY_TIMESTAMP")
  .groupBy($"ID")
  .agg(maxrow(col("AMOUNT"), col("MY_TIMESTAMP")).as("maxrow"))
  .selectExpr("ID", "maxrow.st.AMOUNT", "maxrow.st.MY_TIMESTAMP")
  .writeStream
  .format("console")
  .trigger(Trigger.ProcessingTime("1 seconds"))
  .outputMode(OutputMode.Update())
  .start()
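The update and merge methods above boil down to "keep the (AMOUNT, MY_TIMESTAMP) pair with the larger AMOUNT". As a sanity check outside spark-shell, here is a minimal plain-Scala model of that buffer logic. It does not use Spark at all; MaxRowModel, Rec and Buf are hypothetical names, timestamps are plain strings, and None stands in for the all-null buffer.

```scala
object MaxRowModel {
  // One input row: AMOUNT may be missing (like a null column); the timestamp is a plain string.
  final case class Rec(amount: Option[Int], ts: String)

  // The aggregation buffer: None plays the role of the all-null initial buffer.
  type Buf = Option[(Int, String)]

  val initialize: Buf = None

  // Mirrors MaxRow.update: ignore null inputs, replace the buffer when the input AMOUNT is larger.
  def update(buf: Buf, input: Rec): Buf = input.amount match {
    case None => buf
    case Some(a) =>
      buf match {
        case Some((cur, _)) if cur >= a => buf
        case _                          => Some((a, input.ts))
      }
  }

  // Mirrors MaxRow.merge: keep whichever partial buffer holds the larger AMOUNT.
  def merge(b1: Buf, b2: Buf): Buf = (b1, b2) match {
    case (None, _)                                  => b2
    case (_, None)                                  => b1
    case (Some((a1, _)), Some((a2, _))) if a1 < a2  => b2
    case _                                          => b1
  }

  // A single pass over all rows of one group.
  def aggregate(recs: Seq[Rec]): Buf = recs.foldLeft(initialize)(update)
}
```

On the sample rows for ID 2, a single pass and a merge of two partial buffers both end up with (40, 2018-04-01T01:30:00.000Z). Note that on a tie in AMOUNT the first buffer wins, so which MY_TIMESTAMP is kept can depend on the order in which partial results are merged.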

- Jungtaek Lim (HeartSaVioR)

On Wed, Apr 18, 2018 at 7:41 AM, Jungtaek Lim <> wrote:

> I think I missed something: a self-join is not needed if you define a UDAF
> and use it in the aggregation. Since the UDAF needs access to all fields, I
> couldn't find any approach other than wrapping the fields into a struct and
> unwrapping them afterwards. There doesn't seem to be a way to pass multiple
> fields to a UDAF, at least via RelationalGroupedDataset.
> Here's working code that runs fine in the console:
> ----
> import org.apache.spark.sql.expressions.MutableAggregationBuffer
> import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
> import org.apache.spark.sql.Row
> import org.apache.spark.sql.types._
> import java.sql.Date
> class MaxRow extends UserDefinedAggregateFunction {
>   // These are the input fields for your aggregate function.
>   override def inputSchema: org.apache.spark.sql.types.StructType =
>     new StructType().add("st", StructType(Seq(
>         StructField("AMOUNT", IntegerType, true),
>         StructField("MY_TIMESTAMP", DateType, true))
>         )
>       )
>   // These are the internal fields you keep for computing your aggregate.
>   override def bufferSchema: StructType =
>   new StructType().add("st", StructType(Seq(
>         StructField("AMOUNT", IntegerType, true),
>         StructField("MY_TIMESTAMP", DateType, true))
>         )
>       )
>   // This is the output type of your aggregation function.
>   override def dataType: DataType =
>   new StructType().add("st", StructType(Seq(
>         StructField("AMOUNT", IntegerType, true),
>         StructField("MY_TIMESTAMP", DateType, true))
>         )
>       )
>   override def deterministic: Boolean = true
>   // This is the initial value for your buffer schema.
>   override def initialize(buffer: MutableAggregationBuffer): Unit = {
>     // Start from an empty (null) struct so update/merge can detect it.
>     buffer(0) = null
>   }
>   // This is how to update your buffer schema given an input.
>   override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
>     val inputRowStruct = input.getAs[Row](0)
>     if (buffer.getAs[Row](0) == null ||
>         buffer.getAs[Row](0).getInt(0) < inputRowStruct.getInt(0)) {
>       buffer(0) = inputRowStruct
>     }
>   }
>   // This is how to merge two objects with the bufferSchema type.
>   override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
>     if (buffer1.getAs[Row](0) == null || (buffer2.getAs[Row](0) != null &&
>         buffer1.getAs[Row](0).getInt(0) < buffer2.getAs[Row](0).getInt(0))) {
>       buffer1(0) = buffer2(0)
>     }
>   }
>   // This is where you output the final value, given the final value of your bufferSchema.
>   override def evaluate(buffer: Row): Any = {
>     buffer
>   }
> }
> spark.udf.register("maxrow", new MaxRow)
> import org.apache.spark.sql.functions._
> import org.apache.spark.sql.streaming.{OutputMode, Trigger}
> import org.apache.spark.sql.types._
> import spark.implicits._
> val socketDF = spark
>   .readStream
>   .format("socket")
>   .option("host", "localhost")
>   .option("port", 9999)
>   .load()
> val schema = StructType(Seq(
>   StructField("ID", IntegerType, true),
>   StructField("AMOUNT", IntegerType, true),
>   StructField("MY_TIMESTAMP", DateType, true)
> ))
> val query = socketDF
>   .selectExpr("CAST(value AS STRING) as value")
>   .as[String]
>   .select(from_json($"value", schema=schema).as("data"))
>   .selectExpr("data.ID as ID",
>     "struct(data.AMOUNT, data.MY_TIMESTAMP) as structure")
>   .groupBy($"ID")
>   .agg("structure" -> "maxrow")
>   .selectExpr("ID", "`maxrow(structure)`.st.AMOUNT",
>     "`maxrow(structure)`.st.MY_TIMESTAMP")
>   .writeStream
>   .format("console")
>   .trigger(Trigger.ProcessingTime("1 seconds"))
>   .outputMode(OutputMode.Update())
>   .start()
> ----
> You still want to group records by event-time window and watermark: even
> when I sent all five records to the socket at once (via nc), two
> micro-batches handled the records and produced two results.
> -------------------------------------------
> Batch: 0
> -------------------------------------------
> +---+------+------------+
> | ID|AMOUNT|MY_TIMESTAMP|
> +---+------+------------+
> |  1|    10|  2018-04-01|
> |  2|    30|  2018-04-01|
> +---+------+------------+
> -------------------------------------------
> Batch: 1
> -------------------------------------------
> +---+------+------------+
> | ID|AMOUNT|MY_TIMESTAMP|
> +---+------+------------+
> |  2|    40|  2018-04-01|
> +---+------+------------+
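The two batches above are consistent with how Update mode behaves for a plain streaming aggregation, as far as I understand it: each micro-batch emits the current aggregate for every key that received input in that batch, so ID 2 shows up again with 40. The following plain-Scala model (no Spark; UpdateModeModel is a hypothetical name, timestamps are dropped) reproduces that output, assuming the first four records landed in batch 0 and the last one in batch 1.

```scala
object UpdateModeModel {
  // One micro-batch of (ID, AMOUNT) pairs; timestamps are left out for brevity.
  type Batch = Seq[(Int, Int)]

  // Returns, for each micro-batch, the rows an Update-mode sink would receive.
  def run(batches: Seq[Batch]): Seq[Map[Int, Int]] = {
    val (_, outputs) =
      batches.foldLeft((Map.empty[Int, Int], Vector.empty[Map[Int, Int]])) {
        case ((state, out), batch) =>
          // Fold this batch into the running per-ID max (standing in for the state store).
          val newState = batch.foldLeft(state) { case (s, (id, amount)) =>
            s.updated(id, s.get(id).fold(amount)(math.max(_, amount)))
          }
          // Emit the current max for every ID that received input in this batch.
          val touched = batch.map(_._1).toSet
          (newState, out :+ newState.filter { case (id, _) => touched(id) })
      }
    outputs
  }
}
```

With that split, `UpdateModeModel.run(Seq(Seq((1, 5), (1, 10), (2, 20), (2, 30)), Seq((2, 40))))` yields `Map(1 -> 10, 2 -> 30)` and then `Map(2 -> 40)`, matching Batch 0 and Batch 1. The result per batch depends on how records happen to be split, which is why an event-time window with a watermark is the safer way to define when a result is final.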
> - Jungtaek Lim (HeartSaVioR)
> On Wed, Apr 18, 2018 at 5:56 AM, Jungtaek Lim <> wrote:
>> That would be simple if you just want aggregated values for both AMOUNT
>> and MY_TIMESTAMP:
>>     val schema = StructType(Seq(
>>       StructField("ID", IntegerType, true),
>>       StructField("AMOUNT", IntegerType, true),
>>       StructField("MY_TIMESTAMP", DateType, true)
>>     ))
>>     val query = socketDF
>>       .selectExpr("CAST(value AS STRING) as value")
>>       .as[String]
>>       .select(from_json($"value", schema=schema).as("data"))
>>       .select($"data.ID", $"data.AMOUNT", $"data.MY_TIMESTAMP")
>>       .groupBy($"ID")
>>       .agg("AMOUNT" -> "max", "MY_TIMESTAMP" -> "max")
>> which requires you to set the output mode to Update or Complete.
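To see why taking the max of each column independently is not the same as picking the max row, here is a small plain-Scala sketch (no Spark; MaxVsMaxRow is a hypothetical name). It uses two hypothetical rows for one ID where the larger AMOUNT does not arrive last.

```scala
object MaxVsMaxRow {
  final case class Rec(id: Int, amount: Int, ts: String)

  // Like agg("AMOUNT" -> "max", "MY_TIMESTAMP" -> "max"): each column is maximized
  // independently, so the two values may come from different rows.
  def perColumnMax(recs: Seq[Rec]): Map[Int, (Int, String)] =
    recs.groupBy(_.id).map { case (id, rs) =>
      (id, (rs.map(_.amount).max, rs.map(_.ts).max))
    }

  // What the thread actually asks for: the whole row with the largest AMOUNT.
  def maxRow(recs: Seq[Rec]): Map[Int, (Int, String)] =
    recs.groupBy(_.id).map { case (id, rs) =>
      val top = rs.maxBy(_.amount)
      (id, (top.amount, top.ts))
    }
}
```

For rows (1, 10, 01:10) and (1, 5, 01:20), perColumnMax pairs AMOUNT 10 with 01:20, a combination that never appeared in the input, while maxRow keeps 10 with 01:10. The ISO-8601 strings compare correctly as strings only because they share the same format.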
>> But I guess you would like to select the max row and use MY_TIMESTAMP
>> from that row; in that case I think you need an inner self-join, like below:
>>     val query = socketDF
>>       .selectExpr("CAST(value AS STRING) as value")
>>       .as[String]
>>       .select(from_json($"value", schema=schema).as("data"))
>>       .select($"data.ID", $"data.AMOUNT")
>>       .groupBy($"ID")
>>       .agg("AMOUNT" -> "max")
>>     val query2 = socketDF
>>       .selectExpr("CAST(value AS STRING) as value")
>>       .as[String]
>>       .select(from_json($"value", schema=schema).as("data"))
>>       .select($"data.ID".as("SELF_ID"), $"data.AMOUNT".as("SELF_AMOUNT"),
>>         $"data.MY_TIMESTAMP")
>>     val query3 = query.join(query2, expr("""
>>        ID = SELF_ID AND `max(AMOUNT)` = SELF_AMOUNT
>>     """))
>> which is NOT valid at least for Spark 2.3, because aggregation requires
>> Update/Complete mode but join requires Append mode.
>> (The Structured Streaming programming guide clearly explains this
>> limitation: "Cannot use streaming aggregation before joins.")
>> If you can achieve this with mapGroupsWithState, you may want to stick
>> with that.
>> Btw, when you deal with streaming, you may want to define a logical batch
>> for all aggregations and joins by defining a window and a watermark. You
>> wouldn't want results that differ depending on how records are split into
>> micro-batches, so you always want to deal with event-time windows.
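Conceptually, an event-time window plus a watermark works roughly like the plain-Scala sketch below. It does not use Spark; EventTimeWindowModel, the 1-hour window size and the 10-minute delay are hypothetical choices for illustration. In Spark itself this is what withWatermark together with groupBy(window(...), ...) takes care of, and I believe it would also require MY_TIMESTAMP to be a TimestampType rather than a DateType. Each record is assigned to the tumbling window containing its timestamp, and a window's result can be treated as final once the watermark (max event time seen minus the allowed delay) passes the window's end.

```scala
import java.time.Instant

object EventTimeWindowModel {
  // Parse an ISO-8601 instant such as "2018-04-01T01:10:00.000Z" to epoch millis.
  def millis(iso: String): Long = Instant.parse(iso).toEpochMilli

  // Start of the epoch-aligned tumbling window of size `sizeMillis` containing `tsMillis`.
  def windowStart(tsMillis: Long, sizeMillis: Long): Long =
    tsMillis - Math.floorMod(tsMillis, sizeMillis)

  // The window [start, start + size) is final once the watermark
  // (max event time seen so far minus the allowed delay) reaches its end.
  def isFinal(start: Long, sizeMillis: Long, maxEventTime: Long, delayMillis: Long): Boolean =
    maxEventTime - delayMillis >= start + sizeMillis
}
```

With these assumptions, all five sample records fall into the 01:00 to 02:00 window, which would only be closed after an event at or beyond 02:10 arrives; until then the max row per ID for that window may still change, regardless of how the records were split into micro-batches.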
>> Thanks,
>> Jungtaek Lim (HeartSaVioR)
>> On Wed, Apr 18, 2018 at 3:42 AM, kant kodali <> wrote:
>>> Hi TD,
>>> Thanks for that. The only reason I ask is that I don't see any alternative
>>> solution to the problem below using raw SQL.
>>> How do I select the max row for every group in Spark Structured Streaming
>>> 2.3.0 without using ORDER BY (since it requires Complete mode) or
>>> mapGroupsWithState?
>>> *Input:*
>>> id | amount     | my_timestamp
>>> -------------------------------------------
>>> 1  |      5     |  2018-04-01T01:00:00.000Z
>>> 1  |     10     |  2018-04-01T01:10:00.000Z
>>> 2  |     20     |  2018-04-01T01:20:00.000Z
>>> 2  |     30     |  2018-04-01T01:25:00.000Z
>>> 2  |     40     |  2018-04-01T01:30:00.000Z
>>> *Expected Output:*
>>> id | amount     | my_timestamp
>>> -------------------------------------------
>>> 1  |     10     |  2018-04-01T01:10:00.000Z
>>> 2  |     40     |  2018-04-01T01:30:00.000Z
>>> Looking for a streaming solution using either raw SQL, like
>>> sparkSession.sql("sql query"), or something similar to raw SQL, but not
>>> something like mapGroupsWithState.
>>> On Mon, Apr 16, 2018 at 8:32 PM, Tathagata Das <> wrote:
>>>> Unfortunately, no. Honestly, it does not make sense: for type-aware
>>>> operations like map, mapGroups, etc., you have to provide an actual JVM
>>>> function, which does not fit the structure of the SQL language.
>>>> On Mon, Apr 16, 2018 at 7:34 PM, kant kodali <> wrote:
>>>>> Hi All,
>>>>> Can we use mapGroupsWithState in raw SQL? Or is it on the roadmap?
>>>>> Thanks!

