I am still stuck on this. Does anyone know the correct way to use a custom
Aggregator for a case class with agg?
I would like to use the Dataset API, but it looks like during aggregation Spark
loses the type and falls back to GenericRowWithSchema instead of my case class.
Is that right?
Thanks
From: Yong Zhang
Sent: Thursday, March 22, 2018 10:08 PM
To: user@spark.apache.org
Subject: java.lang.ClassCastException:
org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast
to Case class
I am researching a custom Aggregator implementation, following the
example in the Spark sample code here:
https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala
But I cannot use it in the agg function; I get an error like
java.lang.ClassCastException:
org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast
to my case class. If I don't use group by, it works the same way as in the
sample code. What do I need to change to make it work with group by?
This is on Spark 2.2, as shown below. Following the Spark example, I can run
rawDS.select(ChangeLogAggregator.toColumn.name("change_log")).show(false)
without any issue, but if I run
rawDS.groupBy($"domain").agg(ChangeLogAggregator.toColumn.name("change_log")).show(false)
I get the cast exception. I want to apply my custom Aggregator
implementation per group. How do I fix this?
Thanks
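One direction that may be worth trying, offered as a sketch rather than an answer confirmed in this thread: the untyped groupBy appears to hand Row objects to the aggregator, while the typed groupByKey keeps the case class as the input type. The example below uses a simplified, hypothetical record (SimpleRecord) and aggregator (LatestRecord) in place of DeriveRecord and ChangeLogAggregator, whose method bodies are omitted in this post, and it assumes a local SparkSession.

```scala
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical, simplified record; a stand-in for DeriveRecord.
case class SimpleRecord(domain: String, date: String, flag: Boolean)

// Hypothetical aggregator: keeps the record with the latest date.
// A stand-in for ChangeLogAggregator, whose logic is not shown in the post.
object LatestRecord extends Aggregator[SimpleRecord, SimpleRecord, SimpleRecord] {
  def zero: SimpleRecord = SimpleRecord("", "", flag = false)
  def reduce(buffer: SimpleRecord, curr: SimpleRecord): SimpleRecord =
    if (curr.date > buffer.date) curr else buffer
  def merge(b1: SimpleRecord, b2: SimpleRecord): SimpleRecord =
    if (b2.date > b1.date) b2 else b1
  def finish(buffer: SimpleRecord): SimpleRecord = buffer
  def bufferEncoder: Encoder[SimpleRecord] = Encoders.product[SimpleRecord]
  def outputEncoder: Encoder[SimpleRecord] = Encoders.product[SimpleRecord]
}

object GroupByKeySketch {
  // groupByKey returns a KeyValueGroupedDataset, so the aggregator's
  // reduce receives SimpleRecord values instead of generic Rows.
  def latestPerDomain(ds: Dataset[SimpleRecord]): Dataset[(String, SimpleRecord)] =
    ds.groupByKey(_.domain)(Encoders.STRING)
      .agg(LatestRecord.toColumn.name("latest"))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("groupByKey-sketch")
      .getOrCreate()
    import spark.implicits._

    val ds = Seq(
      SimpleRecord("abc.com", "2017-01-09", true),
      SimpleRecord("123.com", "2015-01-01", false),
      SimpleRecord("abc.com", "2018-01-09", false)
    ).toDS()

    latestPerDomain(ds).show(false)
    spark.stop()
  }
}
```

If this applies, the equivalent call for the code in this post would be rawDS.groupByKey(_.domain).agg(ChangeLogAggregator.toColumn.name("change_log")), which yields one (domain, DeriveRecord) pair per group.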
scala> spark.version
res31: String = 2.2.1
import scala.collection.mutable.ListBuffer

case class FlagChangeLog(date: String, old_flag: Boolean, new_flag: Boolean)
case class DeriveRecord(domain: String, date: String, flag: Boolean, isDelta: Boolean,
  flag_changelog: scala.collection.mutable.ListBuffer[FlagChangeLog])

val rawDS = Seq(
  DeriveRecord("abc.com", "2017-01-09", true, false, ListBuffer.empty),
  DeriveRecord("123.com", "2015-01-01", false, false, ListBuffer.empty),
  DeriveRecord("abc.com", "2018-01-09", false, true, ListBuffer.empty),
  DeriveRecord("123.com", "2017-01-09", true, true, ListBuffer.empty),
  DeriveRecord("xyz.com", "2018-03-09", false, true, ListBuffer.empty)
).toDS
scala> rawDS.show(false)
+-------+----------+-----+-------+--------------+
|domain |date      |flag |isDelta|flag_changelog|
+-------+----------+-----+-------+--------------+
|abc.com|2017-01-09|true |false  |[]            |
|123.com|2015-01-01|false|false  |[]            |
|abc.com|2018-01-09|false|true   |[]            |
|123.com|2017-01-09|true |true   |[]            |
|xyz.com|2018-03-09|false|true   |[]            |
+-------+----------+-----+-------+--------------+
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

object ChangeLogAggregator extends Aggregator[DeriveRecord, DeriveRecord, DeriveRecord] {
  def zero: DeriveRecord = /// omitted
  def reduce(buffer: DeriveRecord, curr: DeriveRecord): DeriveRecord = {
    /// omitted
  }
  def merge(b1: DeriveRecord, b2: DeriveRecord): DeriveRecord = {
    /// omitted
  }
  def finish(output: DeriveRecord): DeriveRecord = {
    /// omitted
  }
  def bufferEncoder: Encoder[DeriveRecord] = Encoders.product
  def outputEncoder: Encoder[DeriveRecord] = Encoders.product
}
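For reference, here is a plain-Scala sketch of how the four Aggregator methods are expected to fit together within one group: reduce folds input records into a buffer, merge combines partial buffers, and finish turns the final buffer into the output. No Spark is involved. Rec, LatestByDate, and the split into two chunks are hypothetical stand-ins chosen for illustration, not the omitted logic of ChangeLogAggregator.

```scala
// Hypothetical simplified record; not the DeriveRecord from this post.
case class Rec(domain: String, date: String, flag: Boolean)

// The four Aggregator operations written as plain functions.
object LatestByDate {
  val zero: Rec = Rec("", "", flag = false)
  def reduce(buffer: Rec, curr: Rec): Rec =
    if (curr.date > buffer.date) curr else buffer
  def merge(b1: Rec, b2: Rec): Rec =
    if (b2.date > b1.date) b2 else b1
  def finish(buffer: Rec): Rec = buffer
}

object AggregatorContractSketch {
  // Simulates one group: split its rows into chunks (standing in for
  // partitions), fold each chunk with reduce, combine the partial
  // buffers with merge, then apply finish.
  def aggregateGroup(rows: Seq[Rec], partitions: Int): Rec = {
    val chunkSize = math.max(1, (rows.size + partitions - 1) / partitions)
    val partials = rows.grouped(chunkSize)
      .map(_.foldLeft(LatestByDate.zero)(LatestByDate.reduce))
    LatestByDate.finish(partials.foldLeft(LatestByDate.zero)(LatestByDate.merge))
  }

  // Applies the aggregation independently to each domain.
  def perDomain(rows: Seq[Rec]): Map[String, Rec] =
    rows.groupBy(_.domain).map { case (domain, group) =>
      domain -> aggregateGroup(group, partitions = 2)
    }
}
```

In this sketch every input to reduce is a Rec, which is the typed behavior the per-group aggregation in this post is hoping for.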
scala> rawDS.select(ChangeLogAggregator.toColumn.name("change_log")).show(false)
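+-------+----------+-----+-------+---------------------------------------------------+
|domain |date      |flag |isDelta|flag_changelog                                     |
+-------+----------+-----+-------+---------------------------------------------------+
|abc.com|2018-01-09|false|true   |[[2015-01-01,true,false], [2018-01-09,false,false]]|
+-------+----------+-----+-------+---------------------------------------------------+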
scala> rawDS.groupBy($"domain").agg(ChangeLogAggregator.toColumn.name("change_log")).show(false)
18/03/22 22:04:44 ERROR Executor: Exception in task 1.0 in stage 36.0 (TID 48)
java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to $line15.$read$$iw$$iw$DeriveRecord
    at $line110.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$ChangeLogAggregator$.reduce(:31)
    at org.apache.spark.sql.execution.aggregate.ComplexTypedAggregateExpression.update(TypedAggregateExpression.scala:239)
    at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:524)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:187)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:181)
at