/*
 * Mailing-list thread preserved from the original text.
 *
 * Reply:
 *
 *   Change: val arrayinput = input.getAs[Array[String]](0)
 *   to:     val arrayinput = input.getAs[Seq[String]](0)
 *
 *   Yong
 *
 * ________________________________
 * From: shyla deshpande <deshpandesh...@gmail.com>
 * Sent: Thursday, March 23, 2017 8:18 PM
 * To: user
 * Subject: Spark dataframe, UserDefinedAggregateFunction(UDAF) help!!
 *
 * This is my input data. The UDAF needs to aggregate the goals for a team and
 * return a map that gives the count for every goal in the team.
 * I am getting the following error:
 *
 *   java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef
 *   cannot be cast to [Ljava.lang.String;
 *     at com.whil.common.GoalAggregator.update(GoalAggregator.scala:27)
 *
 *   +------+--------------+
 *   |teamid|goals         |
 *   +------+--------------+
 *   |t1    |[Goal1, Goal2]|
 *   |t1    |[Goal1, Goal3]|
 *   |t2    |[Goal1, Goal2]|
 *   |t3    |[Goal2, Goal3]|
 *   +------+--------------+
 *
 *   root
 *    |-- teamid: string (nullable = true)
 *    |-- goals: array (nullable = true)
 *    |    |-- element: string (containsNull = true)
 *
 * NOTE: the code below already applies the fix suggested in the reply (the
 * array column is read as a Seq, not an Array). It also corrects the counting
 * in `update` and `merge`; see the comments on those methods.
 */

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

/////////////////////////Calling the UDAF//////////

/** Small driver that builds the sample data from the thread and runs the UDAF per team. */
object TestUDAF {

  /**
   * Builds the four-row sample DataFrame, prints it and its schema, then
   * aggregates the `goals` column per `teamid` with [[GoalAggregator]].
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .getOrCreate()

    // Needed for the `.toDF` conversion on a local Seq of tuples.
    import spark.implicits._

    val data = Seq(
      ("t1", Seq("Goal1", "Goal2")),
      ("t1", Seq("Goal1", "Goal3")),
      ("t2", Seq("Goal1", "Goal2")),
      ("t3", Seq("Goal2", "Goal3"))
    ).toDF("teamid", "goals")

    data.show(truncate = false)
    data.printSchema()

    val sumgoals = new GoalAggregator
    val result = data.groupBy("teamid").agg(sumgoals(col("goals")))
    result.show(truncate = false)
  }
}

///////////////UDAF/////////////////

/**
 * Aggregates an array-of-strings column into a map from each distinct string
 * (a "goal") to the number of times it occurs across all rows of the group.
 */
class GoalAggregator extends UserDefinedAggregateFunction {

  /** Single input column: an array of strings. */
  override def inputSchema: StructType =
    StructType(StructField("value", ArrayType(StringType)) :: Nil)

  /** Single buffer slot: the running goal -> count map. */
  override def bufferSchema: StructType =
    StructType(StructField("combined", MapType(StringType, IntegerType)) :: Nil)

  /** The aggregate result is the goal -> count map. */
  override def dataType: DataType = MapType(StringType, IntegerType)

  override def deterministic: Boolean = true

  /** Starts every group with an empty count map. */
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, Map.empty[String, Int])

  /**
   * Folds the goals of one input row into the running counts.
   *
   * @param buffer aggregation buffer holding the current goal -> count map in slot 0
   * @param input  input row whose slot 0 is the array of goals (may be null)
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    // The column is nullable; a null array contributes nothing.
    if (!input.isNullAt(0)) {
      val counts = buffer.getAs[Map[String, Int]](0)
      // Array columns arrive as a Seq (WrappedArray), never as Array[String];
      // reading them as Array is what caused the reported ClassCastException.
      val goals = input.getSeq[String](0)
      // Fold rather than map-against-the-old-buffer, so a goal repeated within
      // one row is counted once per occurrence. Null elements are dropped
      // because they cannot be used as keys of the result map.
      val updated = goals.flatMap(Option(_)).foldLeft(counts) { (acc, goal) =>
        acc.updated(goal, acc.getOrElse(goal, 0) + 1)
      }
      buffer.update(0, updated)
    }

  /**
   * Combines two partial count maps by summing the counts key by key.
   *
   * @param buffer1 buffer that receives the merged result
   * @param buffer2 other partial aggregation to merge in
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val left = buffer1.getAs[Map[String, Int]](0)
    val right = buffer2.getAs[Map[String, Int]](0)
    // Add the partial count from `right`; adding a constant 1 here would lose
    // every occurrence beyond the first in the second buffer.
    val merged = right.foldLeft(left) { case (acc, (goal, count)) =>
      acc.updated(goal, acc.getOrElse(goal, 0) + count)
    }
    buffer1.update(0, merged)
  }

  /** Returns the final goal -> count map for the group. */
  override def evaluate(buffer: Row): Any =
    buffer.getAs[Map[String, Int]](0)
}