Refined it further: I got rid of wrapping the fields into a struct, but the
result type of the UDAF is still a struct. I need to extract the fields one
by one, but I guess I just haven't found a function that does that.
I wrote this code without an IDE and ran it from spark-shell, so there are
probably many spots where you can make it shorter or clean it up.
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import java.sql.Date
class MaxRow extends UserDefinedAggregateFunction {
  // These are the input fields for your aggregate function.
  override def inputSchema: StructType =
    StructType(Seq(
      StructField("AMOUNT", IntegerType, true),
      StructField("MY_TIMESTAMP", DateType, true)
    ))

  // These are the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType =
    StructType(Seq(
      StructField("AMOUNT", IntegerType, true),
      StructField("MY_TIMESTAMP", DateType, true)
    ))

  // This is the output type of your aggregation function.
  override def dataType: DataType =
    new StructType().add("st", StructType(Seq(
      StructField("AMOUNT", IntegerType, true),
      StructField("MY_TIMESTAMP", DateType, true)
    )))

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema: start empty (null).
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = null
    buffer(1) = null
  }

  // This is how to update your buffer schema given an input.
  // Rows with a null AMOUNT are skipped, since getInt would throw on them.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0) &&
        (buffer.isNullAt(0) || buffer.getInt(0) < input.getInt(0))) {
      buffer(0) = input(0)
      buffer(1) = input(1)
    }
  }

  // This is how to merge two objects with the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (!buffer2.isNullAt(0) &&
        (buffer1.isNullAt(0) || buffer1.getInt(0) < buffer2.getInt(0))) {
      buffer1(0) = buffer2(0)
      buffer1(1) = buffer2(1)
    }
  }

  // This is where you output the final value, given the final value of
  // your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    Row(Row(buffer(0), buffer(1)))
  }
}
val maxrow = new MaxRow
spark.udf.register("maxrow", maxrow)
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import org.apache.spark.sql.types._
import spark.implicits._
val socketDF = spark
.readStream
.format("socket")
.option("host", "localhost")
.option("port", 9999)
.load()
val schema = StructType(Seq(
StructField("ID", IntegerType, true),
StructField("AMOUNT", IntegerType, true),
StructField("MY_TIMESTAMP", DateType, true)
))
val query = socketDF
.selectExpr("CAST(value AS STRING) as value")
.as[String]
.select(from_json($"value", schema=schema).as("data"))
.select($"data.ID", $"data.AMOUNT", $"data.MY_TIMESTAMP")
.groupBy($"ID")
.agg(maxrow(col("AMOUNT"), col("MY_TIMESTAMP")).as("maxrow"))
.selectExpr("ID", "maxrow.st.AMOUNT", "maxrow.st.MY_TIMESTAMP")
.writeStream
.format("console")
.trigger(Trigger.ProcessingTime("1 seconds"))
.outputMode(OutputMode.Update())
.start()
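To sanity-check the update/merge contract without a cluster, here is a pure-Scala model of MaxRow's buffer logic (no Spark involved). It assumes the null-safe update (rows whose AMOUNT is null are skipped) and treats an empty buffer as None; the object and method names are made up for illustration. The property worth checking is that aggregating two partitions separately and merging them gives the same answer as a single pass, which is what Spark relies on when it calls merge.

```scala
// Pure-Scala model of MaxRow's buffer logic (no Spark needed).
// Assumption: a buffer is either empty (None, like the null buffer the UDAF
// starts from) or holds the (AMOUNT, MY_TIMESTAMP) pair of the best row so far.
object MaxRowModel {
  type Buf = Option[(Int, String)]

  // Mirrors update(): skip rows whose AMOUNT is null (None here), otherwise
  // replace the buffer only on a strictly larger amount, so ties keep the
  // row that was seen first.
  def update(buf: Buf, amount: Option[Int], ts: String): Buf =
    amount match {
      case None => buf
      case Some(a) =>
        buf match {
          case Some((best, _)) if best >= a => buf
          case _                            => Some((a, ts))
        }
    }

  // Mirrors merge(): an empty side never overwrites a filled one.
  def merge(b1: Buf, b2: Buf): Buf = (b1, b2) match {
    case (_, None)                                  => b1
    case (None, _)                                  => b2
    case (Some((a1, _)), Some((a2, _))) if a1 < a2 => b2
    case _                                          => b1
  }

  // Aggregate one "partition" of rows, starting from an empty buffer.
  def aggregate(rows: Seq[(Option[Int], String)]): Buf =
    rows.foldLeft(None: Buf) { case (b, (a, t)) => update(b, a, t) }
}
```

For example, splitting ID 2's rows into two partitions and merging yields the same (40, "01:30") result as one pass, and a null AMOUNT in between is simply ignored.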
- Jungtaek Lim (HeartSaVioR)
On Wed, Apr 18, 2018 at 7:41 AM, Jungtaek Lim <[email protected]> wrote:
> I think I missed something: a self-join is not needed if you define a UDAF
> and use it in the aggregation. Since the UDAF needs access to all fields, I
> can't find any approach other than wrapping the fields into a struct and
> unwrapping them afterwards. There doesn't seem to be a way to pass multiple
> fields to a UDAF, at least in RelationalGroupedDataset.
>
> Here's working code that runs fine in the console:
>
> ----
> import org.apache.spark.sql.expressions.MutableAggregationBuffer
> import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
> import org.apache.spark.sql.Row
> import org.apache.spark.sql.types._
> import java.sql.Date
>
> class MaxRow extends UserDefinedAggregateFunction {
>   // These are the input fields for your aggregate function.
>   override def inputSchema: StructType =
>     new StructType().add("st", StructType(Seq(
>       StructField("AMOUNT", IntegerType, true),
>       StructField("MY_TIMESTAMP", DateType, true)
>     )))
>
>   // These are the internal fields you keep for computing your aggregate.
>   override def bufferSchema: StructType =
>     new StructType().add("st", StructType(Seq(
>       StructField("AMOUNT", IntegerType, true),
>       StructField("MY_TIMESTAMP", DateType, true)
>     )))
>
>   // This is the output type of your aggregation function.
>   override def dataType: DataType =
>     new StructType().add("st", StructType(Seq(
>       StructField("AMOUNT", IntegerType, true),
>       StructField("MY_TIMESTAMP", DateType, true)
>     )))
>
>   override def deterministic: Boolean = true
>
>   // This is the initial value for your buffer schema: an empty (null) struct.
>   override def initialize(buffer: MutableAggregationBuffer): Unit = {
>     buffer(0) = null
>   }
>
>   // This is how to update your buffer schema given an input.
>   // Structs that are null or have a null AMOUNT are skipped, since getInt
>   // would throw on them.
>   override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
>     val inputRowStruct = input.getAs[Row](0)
>     if (inputRowStruct != null && !inputRowStruct.isNullAt(0) &&
>         (buffer.isNullAt(0) ||
>          buffer.getAs[Row](0).getInt(0) < inputRowStruct.getInt(0))) {
>       buffer(0) = inputRowStruct
>     }
>   }
>
>   // This is how to merge two objects with the bufferSchema type.
>   override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
>     if (!buffer2.isNullAt(0) &&
>         (buffer1.isNullAt(0) ||
>          buffer1.getAs[Row](0).getInt(0) < buffer2.getAs[Row](0).getInt(0))) {
>       buffer1(0) = buffer2(0)
>     }
>   }
>
>   // This is where you output the final value, given the final value of
>   // your bufferSchema.
>   override def evaluate(buffer: Row): Any = {
>     buffer
>   }
> }
>
> spark.udf.register("maxrow", new MaxRow)
>
> import org.apache.spark.sql.functions._
> import org.apache.spark.sql.streaming.{OutputMode, Trigger}
> import org.apache.spark.sql.types._
> import spark.implicits._
>
> val socketDF = spark
> .readStream
> .format("socket")
> .option("host", "localhost")
> .option("port", 9999)
> .load()
>
> val schema = StructType(Seq(
> StructField("ID", IntegerType, true),
> StructField("AMOUNT", IntegerType, true),
> StructField("MY_TIMESTAMP", DateType, true)
> ))
>
> val query = socketDF
> .selectExpr("CAST(value AS STRING) as value")
> .as[String]
> .select(from_json($"value", schema=schema).as("data"))
> .selectExpr("data.ID as ID", "struct(data.AMOUNT, data.MY_TIMESTAMP) as structure")
> .groupBy($"ID")
> .agg("structure" -> "maxrow")
> .selectExpr("ID", "`maxrow(structure)`.st.AMOUNT", "`maxrow(structure)`.st.MY_TIMESTAMP")
> .writeStream
> .format("console")
> .trigger(Trigger.ProcessingTime("1 seconds"))
> .outputMode(OutputMode.Update())
> .start()
> ----
>
> You still want to group records by an event-time window with a watermark:
> even though I put all five records into the socket at once (via nc), they
> were handled by two micro-batches, which produced two results.
>
> -------------------------------------------
> Batch: 0
> -------------------------------------------
> +---+------+------------+
> | ID|AMOUNT|MY_TIMESTAMP|
> +---+------+------------+
> | 1| 10| 2018-04-01|
> | 2| 30| 2018-04-01|
> +---+------+------------+
> -------------------------------------------
> Batch: 1
> -------------------------------------------
> +---+------+------------+
> | ID|AMOUNT|MY_TIMESTAMP|
> +---+------+------------+
> | 2| 40| 2018-04-01|
> +---+------+------------+
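The two batches above follow from Update mode semantics: each micro-batch writes out only the groups whose aggregate changed in that batch. A toy model in plain Scala (no Spark; the split of records across batches is an assumption consistent with the output shown, with the first four records in batch 0 and the fifth in batch 1, and timestamps are abbreviated to HH:mm) reproduces that output:

```scala
// Toy model of a streaming max-row aggregation in Update mode (not Spark code).
// Assumption: state maps each ID to the best (AMOUNT, MY_TIMESTAMP) seen so far,
// and each batch emits only the IDs whose state changed in that batch.
object UpdateModeModel {
  type Rec = (Int, Int, String) // (ID, AMOUNT, MY_TIMESTAMP)
  type State = Map[Int, (Int, String)]

  def runBatch(state: State, batch: Seq[Rec]): (State, Seq[Rec]) = {
    val newState = batch.foldLeft(state) { case (s, (id, amount, ts)) =>
      s.get(id) match {
        case Some((best, _)) if best >= amount => s
        case _                                 => s.updated(id, (amount, ts))
      }
    }
    // Only keys whose value changed in this batch are written out.
    val changed = newState.keys.toSeq.sorted.filter(id => state.get(id) != newState.get(id))
    (newState, changed.map(id => (id, newState(id)._1, newState(id)._2)))
  }
}
```

Because the state carries over between batches, ID 1 is not repeated in batch 1, while ID 2 is emitted again with its new maximum.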
>
> - Jungtaek Lim (HeartSaVioR)
>
> On Wed, Apr 18, 2018 at 5:56 AM, Jungtaek Lim <[email protected]> wrote:
>
>> That might be simple if you just want aggregated values for both AMOUNT
>> and MY_TIMESTAMP:
>>
>> val schema = StructType(Seq(
>> StructField("ID", IntegerType, true),
>> StructField("AMOUNT", IntegerType, true),
>> StructField("MY_TIMESTAMP", DateType, true)
>> ))
>>
>> val query = socketDF
>> .selectExpr("CAST(value AS STRING) as value")
>> .as[String]
>> .select(from_json($"value", schema=schema).as("data"))
>> .select($"data.ID", $"data.AMOUNT", $"data.MY_TIMESTAMP")
>> .groupBy($"ID")
>> .agg("AMOUNT" -> "max", "MY_TIMESTAMP" -> "max")
>>
>> which requires you to set the output mode to Update or Complete.
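Per-column maxima and "the row with the max AMOUNT" only differ when the largest AMOUNT and the latest MY_TIMESTAMP come from different rows (in the sample input quoted further down they happen to coincide). A plain-Scala illustration with made-up data (not from the thread, no Spark):

```scala
// Hypothetical data showing that independent per-column maxima are not the
// same as taking MY_TIMESTAMP from the row with the largest AMOUNT.
object PerColumnMax {
  // (ID, AMOUNT, MY_TIMESTAMP); ISO timestamps sort chronologically as strings.
  val rows = Seq(
    (1, 10, "2018-04-01T01:00:00Z"),
    (1, 5, "2018-04-01T01:10:00Z")
  )
  // Like .agg("AMOUNT" -> "max", "MY_TIMESTAMP" -> "max") for this one group.
  val perColumn: (Int, String) = (rows.map(_._2).max, rows.map(_._3).max)
  // The max row: MY_TIMESTAMP comes from the same row as the largest AMOUNT.
  val maxRow: (Int, String) = { val r = rows.maxBy(_._2); (r._2, r._3) }
}
```

Here perColumn pairs AMOUNT 10 with the 01:10 timestamp from the other row, while maxRow keeps 01:00.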
>>
>> But I guess you would like to select the max row and take MY_TIMESTAMP
>> from that row; in that case I guess you need an inner self-join, like below:
>>
>> val query = socketDF
>> .selectExpr("CAST(value AS STRING) as value")
>> .as[String]
>> .select(from_json($"value", schema=schema).as("data"))
>> .select($"data.ID", $"data.AMOUNT")
>> .groupBy($"ID")
>> .agg("AMOUNT" -> "max")
>>
>> val query2 = socketDF
>> .selectExpr("CAST(value AS STRING) as value")
>> .as[String]
>> .select(from_json($"value", schema=schema).as("data"))
>> .select($"data.ID".as("SELF_ID"), $"data.AMOUNT".as("SELF_AMOUNT"),
>> $"data.MY_TIMESTAMP")
>>
>> val query3 = query.join(query2, expr("""
>> ID = SELF_ID AND
>> `MAX(AMOUNT)` = SELF_AMOUNT
>> """))
>>
>> which is NOT valid, at least in Spark 2.3, because aggregation requires
>> Update/Complete mode but a join requires Append mode.
>> (The Structured Streaming programming guide states this limitation
>> explicitly: "Cannot use streaming aggregation before joins.")
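One join-free alternative, not tried in this thread, is the built-in max over a struct, e.g. max(struct(AMOUNT, MY_TIMESTAMP)) grouped by ID, then selecting the struct's fields. Spark orders structs field by field, so the struct with the largest AMOUNT should win and MY_TIMESTAMP only breaks ties; this ought to be expressible in plain SQL as well, but it is worth verifying on your Spark version. Scala tuples use the same lexicographic ordering, so the plain-Scala sketch below (no Spark; names are illustrative) models which row such a max would pick:

```scala
// Models field-by-field struct ordering with Scala's lexicographic tuple ordering.
object StructMax {
  type Rec = (Int, Int, String) // (ID, AMOUNT, MY_TIMESTAMP)

  // Per ID, keep the max (AMOUNT, MY_TIMESTAMP) pair, like max(struct(...)).
  // AMOUNT is compared first; MY_TIMESTAMP only matters when AMOUNTs are equal.
  def maxPerId(rows: Seq[Rec]): Map[Int, (Int, String)] =
    rows.groupBy(_._1).map { case (id, rs) => id -> rs.map(r => (r._2, r._3)).max }
}
```

One behavioral difference from MaxRow: on a tie in AMOUNT the later timestamp wins here, whereas MaxRow keeps whichever row it happened to see first.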
>>
>> If you can achieve this with mapGroupsWithState, you may want to stick
>> with that.
>>
>> Btw, when you deal with streaming, you may want to define a logical batch
>> for all aggregations and joins by defining a window and a watermark. You
>> wouldn't want results that differ depending on how records happen to fall
>> into micro-batches, so you will generally want to work with event-time windows.
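In the Dataset API this would mean something like .withWatermark("MY_TIMESTAMP", "10 minutes") followed by .groupBy($"ID", window($"MY_TIMESTAMP", "1 hour")). Both calls exist in Spark 2.3, but the durations here are arbitrary, and MY_TIMESTAMP would likely have to be TimestampType rather than DateType. The plain-Scala sketch below only models the arithmetic: a tumbling window's start is the event time rounded down to a multiple of the window size, and the watermark trails the largest event time seen by the delay threshold.

```scala
// Arithmetic model of tumbling event-time windows and a watermark (no Spark).
// Times are epoch milliseconds; the window size and delay are assumed values.
object WindowModel {
  val windowMs: Long = 60L * 60 * 1000 // 1-hour tumbling windows (assumed)
  val delayMs: Long = 10L * 60 * 1000  // 10-minute watermark delay (assumed)

  // Start of the tumbling window containing t (t assumed non-negative).
  def windowStart(t: Long): Long = t - (t % windowMs)

  // Watermark after seeing these event times: max event time minus the delay.
  def watermark(eventTimes: Seq[Long]): Long = eventTimes.max - delayMs

  // A window [start, start + windowMs) can be finalized once the watermark
  // has reached its end; later data for it may then be dropped.
  def isClosed(start: Long, wm: Long): Boolean = wm >= start + windowMs
}
```

With the sample data, all five records fall into the same 01:00 to 02:00 window, so the result no longer depends on which micro-batch a record landed in; the window is only final once an event around 02:10 or later moves the watermark past 02:00.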
>>
>> Thanks,
>> Jungtaek Lim (HeartSaVioR)
>>
>> On Wed, Apr 18, 2018 at 3:42 AM, kant kodali <[email protected]> wrote:
>>
>>> Hi TD,
>>>
>>> Thanks for that. The only reason I ask is that I don't see any alternative
>>> way to solve the problem below using raw SQL.
>>>
>>>
>>> How do I select the max row for every group in Spark Structured Streaming
>>> 2.3.0 without using ORDER BY (since it requires Complete mode) or
>>> mapGroupsWithState?
>>>
>>> *Input:*
>>>
>>> id | amount | my_timestamp
>>> -------------------------------------------
>>> 1 | 5 | 2018-04-01T01:00:00.000Z
>>> 1 | 10 | 2018-04-01T01:10:00.000Z
>>> 2 | 20 | 2018-04-01T01:20:00.000Z
>>> 2 | 30 | 2018-04-01T01:25:00.000Z
>>> 2 | 40 | 2018-04-01T01:30:00.000Z
>>>
>>> *Expected Output:*
>>>
>>> id | amount | my_timestamp
>>> -------------------------------------------
>>> 1 | 10 | 2018-04-01T01:10:00.000Z
>>> 2 | 40 | 2018-04-01T01:30:00.000Z
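Stated as a plain batch computation, the requirement is "for each id, keep the entire row whose amount is largest". A plain-Scala reference (no Spark; object and value names are illustrative) that reproduces the expected output from the input above can be handy for checking a streaming query's result:

```scala
// Batch (non-streaming) reference for the expected output, in plain Scala.
object ExpectedOutput {
  // (id, amount, my_timestamp) rows from the input table.
  val input: Seq[(Int, Int, String)] = Seq(
    (1, 5, "2018-04-01T01:00:00.000Z"),
    (1, 10, "2018-04-01T01:10:00.000Z"),
    (2, 20, "2018-04-01T01:20:00.000Z"),
    (2, 30, "2018-04-01T01:25:00.000Z"),
    (2, 40, "2018-04-01T01:30:00.000Z")
  )

  // For each id keep the whole row with the largest amount, ordered by id.
  val maxRows: Seq[(Int, Int, String)] =
    input.groupBy(_._1).values.map(_.maxBy(_._2)).toSeq.sortBy(_._1)
}
```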
>>>
>>> I'm looking for a streaming solution using either raw SQL, like
>>> sparkSession.sql("sql query"), or something similar to raw SQL, but not
>>> something like mapGroupsWithState.
>>>
>>> On Mon, Apr 16, 2018 at 8:32 PM, Tathagata Das <
>>> [email protected]> wrote:
>>>
>>>> Unfortunately, no. Honestly, it does not make sense: for type-aware
>>>> operations like map, mapGroups, etc., you have to provide an actual JVM
>>>> function, which does not fit into the structure of the SQL language.
>>>>
>>>> On Mon, Apr 16, 2018 at 7:34 PM, kant kodali <[email protected]>
>>>> wrote:
>>>>
>>>>> Hi All,
>>>>>
>>>>> Can we use mapGroupsWithState in raw SQL? Or is it on the roadmap?
>>>>>
>>>>> Thanks!