[ 
https://issues.apache.org/jira/browse/SPARK-11372?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
 ]

Pravin Gadakh updated SPARK-11372:
----------------------------------
    Description: 
Consider the following custom UDAF, which uses a StringType column as the intermediate 
buffer's schema: 

{code:title=Merge.scala|borderStyle=solid}
class Merge extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(StructField("value", 
StringType) :: Nil)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = if (buffer.getAs[String](0).isEmpty) {
      input.getAs[String](0)
    } else {
      buffer.getAs[String](0) + "," + input.getAs[String](0)
    }
  }

  override def bufferSchema: StructType = StructType(
    StructField("merge", StringType) :: Nil
  )

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = if (buffer1.getAs[String](0).isEmpty) {
      buffer2.getAs[String](0)
    } else if (!buffer2.getAs[String](0).isEmpty) {
      buffer1.getAs[String](0) + "," + buffer2.getAs[String](0)
    }
  }

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }

  override def deterministic: Boolean = true

  override def evaluate(buffer: Row): Any = {
    buffer.getAs[String](0)
  }

  override def dataType: DataType = StringType
}
{code}

Running the UDAF: 
{code}
  val df: DataFrame = sc.parallelize(List("abcd", "efgh", "hijk")).toDF("alpha")
  val merge = new Merge()
  df.groupBy().agg(merge(col("alpha")).as("Merge")).show()
{code}

Throws the following exception: 

{code}
java.lang.ClassCastException: java.lang.String cannot be cast to 
org.apache.spark.unsafe.types.UTF8String
  at 
org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getUTF8String(rows.scala:45)
  at 
org.apache.spark.sql.catalyst.expressions.GenericMutableRow.getUTF8String(rows.scala:247)
  at 
org.apache.spark.sql.catalyst.expressions.JoinedRow.getUTF8String(JoinedRow.scala:102)
  at 
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown
 Source)
  at 
org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$32.apply(AggregationIterator.scala:373)
  at 
org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$32.apply(AggregationIterator.scala:362)
  at 
org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:141)
  at 
org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:30)
  at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
  at scala.collection.Iterator$$anon$10.next(Iterator.scala:312)
  at scala.collection.Iterator$class.foreach(Iterator.scala:727)
  at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
  at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
  at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
  at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
  at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
  at scala.collection.AbstractIterator.to(Iterator.scala:1157)
  at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
  at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
  at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
  at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
  at 
org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:215)
  at 
org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:215)
  at 
org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1839)
  at 
org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1839)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
  at org.apache.spark.scheduler.Task.run(Task.scala:88)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
  at 
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
  at 
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
  at java.lang.Thread.run(Thread.java:745)
{code}

  was:
Consider the following custom UDAF, which uses a StringType column as the intermediate 
buffer's schema: 

{code:title=Merge.scala|borderStyle=solid}
class Merge extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(StructField("value", 
StringType) :: Nil)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = if (buffer.getAs[String](0).isEmpty) {
      input.getAs[String](0)
    } else {
      buffer.getAs[String](0) + "," + input.getAs[String](0)
    }
  }

  override def bufferSchema: StructType = StructType(
    StructField("merge", StringType) :: Nil
  )

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = if (buffer1.getAs[String](0).isEmpty) {
      buffer2.getAs[String](0)
    } else if (!buffer2.getAs[String](0).isEmpty) {
      buffer1.getAs[String](0) + "," + buffer2.getAs[String](0)
    }
  }

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }

  override def deterministic: Boolean = true

  override def evaluate(buffer: Row): Any = {
    buffer.getAs[String](0)
  }

  override def dataType: DataType = StringType
}
{code}

Running the UDAF: 
{code}
  val df: DataFrame = sc.parallelize(List("abcd", "efgh", "hijk")).toDF("alpha")
  val merge = new Merge()
  df.groupBy().agg(merge(col("alpha")).as("Merge")).show()
{code}

Throws the following exception: 

{code}
java.lang.ClassCastException: java.lang.String cannot be cast to 
org.apache.spark.unsafe.types.UTF8String
  at 
org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getUTF8String(rows.scala:45)
  at 
org.apache.spark.sql.catalyst.expressions.GenericMutableRow.getUTF8String(rows.scala:247)
  at 
org.apache.spark.sql.catalyst.expressions.JoinedRow.getUTF8String(JoinedRow.scala:102)
  at 
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown
 Source)
  at 
org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$32.apply(AggregationIterator.scala:373)
  at 
org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$32.apply(AggregationIterator.scala:362)
  at 
org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:141)
  at 
org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:30)
  at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
  at scala.collection.Iterator$$anon$10.next(Iterator.scala:312)
  at scala.collection.Iterator$class.foreach(Iterator.scala:727)
  at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
  at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
  at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
  at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
  at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
  at scala.collection.AbstractIterator.to(Iterator.scala:1157)
  at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
  at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
  at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
  at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
  at 
org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:215)
  at 
org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:215)
  at 
org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1839)
  at 
org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1839)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
  at org.apache.spark.scheduler.Task.run(Task.scala:88)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
  at 
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
  at 
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
  at java.lang.Thread.run(Thread.java:745)
{code}


> custom UDAF with StringType throws java.lang.ClassCastException
> ---------------------------------------------------------------
>
>                 Key: SPARK-11372
>                 URL: https://issues.apache.org/jira/browse/SPARK-11372
>             Project: Spark
>          Issue Type: Bug
>          Components: SQL
>    Affects Versions: 1.5.0
>            Reporter: Pravin Gadakh
>            Priority: Minor
>
> Consider the following custom UDAF, which uses a StringType column as the intermediate 
> buffer's schema: 
> {code:title=Merge.scala|borderStyle=solid}
> class Merge extends UserDefinedAggregateFunction {
>   override def inputSchema: StructType = StructType(StructField("value", 
> StringType) :: Nil)
>   override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
>     buffer(0) = if (buffer.getAs[String](0).isEmpty) {
>       input.getAs[String](0)
>     } else {
>       buffer.getAs[String](0) + "," + input.getAs[String](0)
>     }
>   }
>   override def bufferSchema: StructType = StructType(
>     StructField("merge", StringType) :: Nil
>   )
>   override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = 
> {
>     buffer1(0) = if (buffer1.getAs[String](0).isEmpty) {
>       buffer2.getAs[String](0)
>     } else if (!buffer2.getAs[String](0).isEmpty) {
>       buffer1.getAs[String](0) + "," + buffer2.getAs[String](0)
>     }
>   }
>   override def initialize(buffer: MutableAggregationBuffer): Unit = {
>     buffer(0) = ""
>   }
>   override def deterministic: Boolean = true
>   override def evaluate(buffer: Row): Any = {
>     buffer.getAs[String](0)
>   }
>   override def dataType: DataType = StringType
> }
> {code}
> Running the UDAF: 
> {code}
>   val df: DataFrame = sc.parallelize(List("abcd", "efgh", 
> "hijk")).toDF("alpha")
>   val merge = new Merge()
>   df.groupBy().agg(merge(col("alpha")).as("Merge")).show()
> {code}
> Throws the following exception: 
> {code}
> java.lang.ClassCastException: java.lang.String cannot be cast to 
> org.apache.spark.unsafe.types.UTF8String
>   at 
> org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getUTF8String(rows.scala:45)
>   at 
> org.apache.spark.sql.catalyst.expressions.GenericMutableRow.getUTF8String(rows.scala:247)
>   at 
> org.apache.spark.sql.catalyst.expressions.JoinedRow.getUTF8String(JoinedRow.scala:102)
>   at 
> org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown
>  Source)
>   at 
> org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$32.apply(AggregationIterator.scala:373)
>   at 
> org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$32.apply(AggregationIterator.scala:362)
>   at 
> org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:141)
>   at 
> org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:30)
>   at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
>   at scala.collection.Iterator$$anon$10.next(Iterator.scala:312)
>   at scala.collection.Iterator$class.foreach(Iterator.scala:727)
>   at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
>   at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
>   at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
>   at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
>   at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
>   at scala.collection.AbstractIterator.to(Iterator.scala:1157)
>   at 
> scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
>   at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
>   at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
>   at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
>   at 
> org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:215)
>   at 
> org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:215)
>   at 
> org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1839)
>   at 
> org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1839)
>   at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
>   at org.apache.spark.scheduler.Task.run(Task.scala:88)
>   at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
>   at 
> java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
>   at 
> java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
>   at java.lang.Thread.run(Thread.java:745)
> {code}



--
This message was sent by Atlassian JIRA
(v6.3.4#6332)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org

Reply via email to