I think that you are hitting a bug (which should be fixed in Spark 1.5.1).
I'm hoping we can cut an RC for that this week.  Until then you could try
building branch-1.5.

On Tue, Sep 22, 2015 at 11:13 AM, Deenar Toraskar <deenar.toras...@gmail.com
> wrote:

> Hi
>
> I am trying to write a UDAF, ArraySum, that does an element-wise sum of
> arrays of Doubles, returning an array of Doubles, following the sample in
>
> https://databricks.com/blog/2015/09/16/spark-1-5-dataframe-api-highlights-datetimestring-handling-time-intervals-and-udafs.html.
> I am getting the following error. Any guidance on handling complex types in
> Spark SQL would be appreciated.
>
> Regards
> Deenar
>
> import org.apache.spark.sql.expressions.MutableAggregationBuffer
> import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
> import org.apache.spark.sql.Row
> import org.apache.spark.sql.types._
> import org.apache.spark.sql.functions._
>
> /**
>  * User-defined aggregate function computing the element-wise sum of
>  * arrays of doubles; the result is itself an array of doubles.
>  *
>  * Arrays of different lengths are summed as if the shorter one were
>  * padded with zeros, so the result is as long as the longest input.
>  * An empty array (or a null array cell) leaves the running sum unchanged.
>  */
> class ArraySum extends UserDefinedAggregateFunction {
>
>   // array<double> without null elements; shared by input, buffer, result.
>   private def doubleArrayType: DataType =
>     ArrayType(DoubleType, containsNull = false)
>
>   /** A single input column holding the array to add. */
>   def inputSchema: StructType =
>     StructType(StructField("value", doubleArrayType) :: Nil)
>
>   /** The aggregation buffer holds the running element-wise sum. */
>   def bufferSchema: StructType =
>     StructType(StructField("value", doubleArrayType) :: Nil)
>
>   def dataType: DataType = doubleArrayType
>
>   def deterministic: Boolean = true
>
>   /** Starts from the empty array, the identity of the element-wise sum. */
>   def initialize(buffer: MutableAggregationBuffer): Unit = {
>     buffer(0) = Seq.empty[Double]
>   }
>
>   /** Adds the array in column 0 of `input` to the running sum. */
>   def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
>     buffer(0) = addElementwise(doubles(buffer), doubles(input))
>   }
>
>   /** Adds the partial sum held in `buffer2` into `buffer1`. */
>   def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
>     buffer1(0) = addElementwise(doubles(buffer1), doubles(buffer2))
>   }
>
>   /** Returns the accumulated sum as a Seq[Double]. */
>   def evaluate(buffer: Row): Any = doubles(buffer)
>
>   // Reads column 0 as a Seq[Double]; a null cell is treated as empty.
>   private def doubles(row: Row): Seq[Double] =
>     if (row.isNullAt(0)) Seq.empty[Double] else row.getSeq[Double](0)
>
>   // zipAll pads the shorter side with 0.0, so no element is dropped and
>   // an empty side simply yields the other side unchanged.
>   private def addElementwise(xs: Seq[Double],
>                              ys: Seq[Double]): Seq[Double] =
>     xs.zipAll(ys, 0.0, 0.0).map { case (x, y) => x + y }
> }
>
> // Create the aggregate and make it callable from SQL as "ArraySum".
> val arraySum = new ArraySum()
> sqlContext.udf.register("ArraySum", arraySum)
>
> *%sql select ArraySum(Array(1.0,2.0,3.0)) from pnls where date =
> '2015-05-22' limit 10*
>
> gives me the following error
>
>
> Error in SQL statement: SparkException: Job aborted due to stage failure:
> Task 0 in stage 219.0 failed 4 times, most recent failure: Lost task 0.3 in
> stage 219.0 (TID 11242, 10.172.255.236): java.lang.ClassCastException:
> scala.collection.mutable.WrappedArray$ofRef cannot be cast to
> org.apache.spark.sql.types.ArrayData at
> org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getArray(rows.scala:47)
> at
> org.apache.spark.sql.catalyst.expressions.GenericMutableRow.getArray(rows.scala:247)
> at
> org.apache.spark.sql.catalyst.expressions.JoinedRow.getArray(JoinedRow.scala:108)
> at
> org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown
> Source) at
> org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$32.apply(AggregationIterator.scala:373)
> at
> org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$32.apply(AggregationIterator.scala:362)
> at
> org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:141)
> at
> org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:30)
> at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) at
> scala.collection.Iterator$$anon$10.next(Iterator.scala:312) at
> scala.collection.Iterator$class.foreach(Iterator.scala:727) at
> scala.collection.AbstractIterator.foreach(Iterator.scala:1157) at
> scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48) at
> scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
> at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
> at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
> at scala.collection.AbstractIterator.to(Iterator.scala:1157) at
> scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
> at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157) at
> scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
> at scala.collection.AbstractIterator.toArray(Iterator.scala:1157) at
> org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:215)
> at
> org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:215)
> at
> org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1839)
> at
> org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1839)
> at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66) at
> org.apache.spark.scheduler.Task.run(Task.scala:88) at
> org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214) at
> java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
> at
> java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
> at java.lang.Thread.run(Thread.java:745)
>
>
>

Reply via email to