Izhar Ahmed created SPARK-23746:
-----------------------------------

             Summary: HashMap UserDefinedType giving cast exception in Spark 1.6.2 while implementing UDAF
                 Key: SPARK-23746
                 URL: https://issues.apache.org/jira/browse/SPARK-23746
             Project: Spark
          Issue Type: Bug
          Components: Spark Core, SQL
    Affects Versions: 1.6.2
            Reporter: Izhar Ahmed


I am trying to use a custom HashMap implementation as a UserDefinedType instead of MapType in Spark. The code *works fine in Spark 1.5.2* but throws {{java.lang.ClassCastException: scala.collection.immutable.HashMap$HashMap1 cannot be cast to org.apache.spark.sql.catalyst.util.MapData}} *in Spark 1.6.2*.

The code:
{code:scala}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.collection.immutable.HashMap

class Test extends UserDefinedAggregateFunction {

  def inputSchema: StructType =
    StructType(Array(StructField("input", StringType)))

  def bufferSchema: StructType =
    StructType(Array(StructField("top_n", CustomHashMapType)))

  def dataType: DataType = CustomHashMapType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = HashMap.empty[String, Long]
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Dummy update: count every row under the literal key "test".
    val buff0 = buffer.getAs[HashMap[String, Long]](0)
    buffer(0) = buff0.updated("test", buff0.getOrElse("test", 0L) + 1L)
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Merge the two maps, summing the counts for keys present in both.
    buffer1(0) = buffer1
      .getAs[HashMap[String, Long]](0)
      .merged(buffer2.getAs[HashMap[String, Long]](0)) { case ((k, v1), (_, v2)) => (k, v1 + v2) }
  }

  def evaluate(buffer: Row): Any = {
    buffer(0)
  }
}

private case object CustomHashMapType extends UserDefinedType[HashMap[String, Long]] {

  override def sqlType: DataType = MapType(StringType, LongType)

  // Note: serialize hands back the Scala Map as-is rather than a Catalyst internal value.
  override def serialize(obj: Any): Map[String, Long] =
    obj.asInstanceOf[Map[String, Long]]

  override def deserialize(datum: Any): HashMap[String, Long] = {
    // Rebuild an immutable HashMap from the incoming map.
    datum.asInstanceOf[Map[String, Long]] ++: HashMap.empty[String, Long]
  }

  override def userClass: Class[HashMap[String, Long]] =
    classOf[HashMap[String, Long]]

}
{code}
The wrapper class to run the UDAF:
{code:scala}
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object TestJob {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[4]").setAppName("DataStatsExecution")
    val sc = new SparkContext(conf)

    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    val df = sc.parallelize(Seq(1,2,3,4)).toDF("col")
    val udaf = new Test()
    val outdf = df.agg(udaf(df("col")))
    outdf.show
  }
}
{code}

Stack trace:
{code:java}
Caused by: java.lang.ClassCastException: scala.collection.immutable.HashMap$HashMap1 cannot be cast to org.apache.spark.sql.catalyst.util.MapData
    at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getMap(rows.scala:50)
    at org.apache.spark.sql.catalyst.expressions.GenericMutableRow.getMap(rows.scala:248)
    at org.apache.spark.sql.catalyst.expressions.JoinedRow.getMap(JoinedRow.scala:115)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$31.apply(AggregationIterator.scala:345)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$31.apply(AggregationIterator.scala:344)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:154)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:149)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
    at org.apache.spark.scheduler.Task.run(Task.scala:89)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:227)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

{code}
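
From the stack trace, the generated projection calls {{getMap}} on the aggregation buffer, which casts the slot to Catalyst's internal {{org.apache.spark.sql.catalyst.util.MapData}}; a plain Scala {{HashMap}} stored there can no longer satisfy that cast in 1.6. Below is a minimal, untested sketch of a workaround UDT ({{CatalystHashMapType}} is a hypothetical name) whose {{serialize}}/{{deserialize}} go through {{MapData}}, assuming the 1.6 internal classes {{ArrayBasedMapData}}, {{GenericArrayData}} and {{UTF8String}}:
{code:scala}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.immutable.HashMap

// Hypothetical, untested sketch: same UDT as above, but serialize/deserialize
// use Catalyst's internal MapData representation instead of a raw Scala Map.
private case object CatalystHashMapType extends UserDefinedType[HashMap[String, Long]] {

  override def sqlType: DataType = MapType(StringType, LongType)

  // Scala HashMap -> Catalyst MapData (string keys become UTF8String internally).
  override def serialize(obj: Any): MapData = {
    val map = obj.asInstanceOf[HashMap[String, Long]]
    val keys = map.keys.map(k => UTF8String.fromString(k): Any).toArray
    val values = map.values.map(v => v: Any).toArray
    new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
  }

  // Catalyst MapData -> Scala HashMap.
  override def deserialize(datum: Any): HashMap[String, Long] = {
    val mapData = datum.asInstanceOf[MapData]
    var result = HashMap.empty[String, Long]
    mapData.foreach(StringType, LongType, (k, v) =>
      result = result.updated(k.asInstanceOf[UTF8String].toString, v.asInstanceOf[Long]))
    result
  }

  override def userClass: Class[HashMap[String, Long]] = classOf[HashMap[String, Long]]
}
{code}
If this is the intended contract for UDTs in 1.6 (serialize to the internal Catalyst type named by {{sqlType}}), it would be good to have that documented; otherwise this looks like a regression from 1.5.2.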


