Repository: spark
Updated Branches:
  refs/heads/master 769aa0f1d -> 93aa42715


[SPARK-19691][SQL] Fix ClassCastException when calculating percentile of 
decimal column

## What changes were proposed in this pull request?
This PR fixes the class-cast exception shown below:
```
scala> spark.range(10).selectExpr("cast (id as decimal) as 
x").selectExpr("percentile(x, 0.5)").collect()
 java.lang.ClassCastException: org.apache.spark.sql.types.Decimal cannot be 
cast to java.lang.Number
        at 
org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:141)
        at 
org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:58)
        at 
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:514)
        at 
org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
        at 
org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
        at 
org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:187)
        at 
org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:181)
        at 
org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:151)
        at 
org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
        at 
org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:109)
        at
```
This fix simply converts Catalyst values (i.e., `Decimal`) into Scala ones by 
using `CatalystTypeConverters`.

## How was this patch tested?
Added a test in `DataFrameSuite`.

Author: Takeshi Yamamuro <yamam...@apache.org>

Closes #17028 from maropu/SPARK-19691.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/93aa4271
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/93aa4271
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/93aa4271

Branch: refs/heads/master
Commit: 93aa4271596a30752dc5234d869c3ae2f6e8e723
Parents: 769aa0f
Author: Takeshi Yamamuro <yamam...@apache.org>
Authored: Thu Feb 23 16:28:36 2017 +0100
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Thu Feb 23 16:28:36 2017 +0100

----------------------------------------------------------------------
 .../expressions/aggregate/Percentile.scala      | 48 +++++++++++---------
 .../expressions/aggregate/PercentileSuite.scala | 29 ++++++------
 .../org/apache/spark/sql/DataFrameSuite.scala   |  5 ++
 3 files changed, 45 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/93aa4271/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 6b7cf79..8433a93 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, 
DataOutputStream}
 import java.util
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, 
TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
@@ -61,7 +61,7 @@ case class Percentile(
     frequencyExpression : Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with 
ImplicitCastInputTypes {
+  extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with 
ImplicitCastInputTypes {
 
   def this(child: Expression, percentageExpression: Expression) = {
     this(child, percentageExpression, Literal(1L), 0, 0)
@@ -130,15 +130,20 @@ case class Percentile(
     }
   }
 
-  override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+  private def toDoubleValue(d: Any): Double = d match {
+    case d: Decimal => d.toDouble
+    case n: Number => n.doubleValue
+  }
+
+  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
     // Initialize new counts map instance here.
-    new OpenHashMap[Number, Long]()
+    new OpenHashMap[AnyRef, Long]()
   }
 
   override def update(
-      buffer: OpenHashMap[Number, Long],
-      input: InternalRow): OpenHashMap[Number, Long] = {
-    val key = child.eval(input).asInstanceOf[Number]
+      buffer: OpenHashMap[AnyRef, Long],
+      input: InternalRow): OpenHashMap[AnyRef, Long] = {
+    val key = child.eval(input).asInstanceOf[AnyRef]
     val frqValue = frequencyExpression.eval(input)
 
     // Null values are ignored in counts map.
@@ -155,32 +160,32 @@ case class Percentile(
   }
 
   override def merge(
-      buffer: OpenHashMap[Number, Long],
-      other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = {
+      buffer: OpenHashMap[AnyRef, Long],
+      other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = {
     other.foreach { case (key, count) =>
       buffer.changeValue(key, count, _ + count)
     }
     buffer
   }
 
-  override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+  override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     generateOutput(getPercentiles(buffer))
   }
 
-  private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = 
{
+  private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = 
{
     if (buffer.isEmpty) {
       return Seq.empty
     }
 
     val sortedCounts = buffer.toSeq.sortBy(_._1)(
-      
child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+      
child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
     val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
     val maxPosition = accumlatedCounts.last._2 - 1
 
     percentages.map { percentile =>
-      getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
+      getPercentile(accumlatedCounts, maxPosition * percentile)
     }
   }
 
@@ -200,7 +205,7 @@ case class Percentile(
    * This function has been based upon similar function from HIVE
    * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
    */
-  private def getPercentile(aggreCounts: Seq[(Number, Long)], position: 
Double): Number = {
+  private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: 
Double): Double = {
     // We may need to do linear interpolation to get the exact percentile
     val lower = position.floor.toLong
     val higher = position.ceil.toLong
@@ -213,18 +218,17 @@ case class Percentile(
     val lowerKey = aggreCounts(lowerIndex)._1
     if (higher == lower) {
       // no interpolation needed because position does not have a fraction
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     val higherKey = aggreCounts(higherIndex)._1
     if (higherKey == lowerKey) {
       // no interpolation needed because lower position and higher position 
has the same key
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     // Linear interpolation to get the exact percentile
-    return (higher - position) * lowerKey.doubleValue() +
-      (position - lower) * higherKey.doubleValue()
+    (higher - position) * toDoubleValue(lowerKey) + (position - lower) * 
toDoubleValue(higherKey)
   }
 
   /**
@@ -238,7 +242,7 @@ case class Percentile(
     }
   }
 
-  override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+  override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
     val buffer = new Array[Byte](4 << 10)  // 4K
     val bos = new ByteArrayOutputStream()
     val out = new DataOutputStream(bos)
@@ -261,11 +265,11 @@ case class Percentile(
     }
   }
 
-  override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+  override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(bis)
     try {
-      val counts = new OpenHashMap[Number, Long]
+      val counts = new OpenHashMap[AnyRef, Long]
       // Read unsafeRow size and content in bytes.
       var sizeOfNextRow = ins.readInt()
       while (sizeOfNextRow >= 0) {
@@ -274,7 +278,7 @@ case class Percentile(
         val row = new UnsafeRow(2)
         row.pointTo(bs, sizeOfNextRow)
         // Insert the pairs into counts map.
-        val key = row.get(0, child.dataType).asInstanceOf[Number]
+        val key = row.get(0, child.dataType)
         val count = row.get(1, LongType).asInstanceOf[Long]
         counts.update(key, count)
         sizeOfNextRow = ins.readInt()

http://git-wip-us.apache.org/repos/asf/spark/blob/93aa4271/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index 1533fe5..2420ba5 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.SparkException
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types._
@@ -39,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(BoundReference(0, IntegerType, true), 
Literal(0.5))
 
     // Check empty serialize and deserialize
-    val buffer = new OpenHashMap[Number, Long]()
+    val buffer = new OpenHashMap[AnyRef, Long]()
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
 
     // Check non-empty buffer serializa and deserialize.
     data.foreach { key =>
-      buffer.changeValue(key, 1L, _ + 1L)
+      buffer.changeValue(new Integer(key), 1L, _ + 1L)
     }
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
   }
@@ -58,25 +57,25 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(childExpression, percentageExpression)
 
     // Test with rows without frequency
-    val rows = (1 to count).map( x => Seq(x))
-    runTest( agg, rows, expectedPercentiles)
+    val rows = (1 to count).map(x => Seq(x))
+    runTest(agg, rows, expectedPercentiles)
 
     // Test with row with frequency. Second and third columns are frequency in 
Int and Long
     val countForFrequencyTest = 1000
-    val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ 
x.toLong)
+    val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ 
x.toLong)
     val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)
 
     val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = 
false)
     val aggInt = new Percentile(childExpression, percentageExpression, 
frequencyExpressionInt)
-    runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
+    runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
 
     val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
     val aggLong = new Percentile(childExpression, percentageExpression, 
frequencyExpressionLong)
-    runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
+    runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
 
     // Run test with Flatten data
-    val flattenRows = (1 to countForFrequencyTest).flatMap( current =>
-      (1 to current).map( y => current )).map( Seq(_))
+    val flattenRows = (1 to countForFrequencyTest).flatMap(current =>
+      (1 to current).map(y => current )).map(Seq(_))
     runTest(agg, flattenRows, expectedPercentilesWithFrquency)
   }
 
@@ -153,7 +152,7 @@ class PercentileSuite extends SparkFunSuite {
     }
 
     val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
-    for ( dataType <- validDataTypes;
+    for (dataType <- validDataTypes;
       frequencyType <- validFrequencyTypes)  {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
@@ -176,7 +175,7 @@ class PercentileSuite extends SparkFunSuite {
         StringType, DateType, TimestampType,
       CalendarIntervalType, NullType)
 
-    for( dataType <- invalidDataTypes;
+    for(dataType <- invalidDataTypes;
         frequencyType <- validFrequencyTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
@@ -186,7 +185,7 @@ class PercentileSuite extends SparkFunSuite {
             s"'`a`' is of ${dataType.simpleString} type."))
     }
 
-    for( dataType <- validDataTypes;
+    for(dataType <- validDataTypes;
         frequencyType <- invalidFrequencyDataTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
@@ -294,11 +293,11 @@ class PercentileSuite extends SparkFunSuite {
         agg.update(buffer, InternalRow(1, -5))
         agg.eval(buffer)
       }
-    assert( caught.getMessage.startsWith("Negative values found in "))
+    assert(caught.getMessage.startsWith("Negative values found in "))
   }
 
   private def compareEquals(
-      left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): 
Boolean = {
+      left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): 
Boolean = {
     left.size == right.size && left.forall { case (key, count) =>
       right.apply(key) == count
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/93aa4271/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e6338ab..5e65436 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1702,4 +1702,9 @@ class DataFrameSuite extends QueryTest with 
SharedSQLContext {
     val df = Seq(123L -> "123", 19157170390056973L -> 
"19157170390056971").toDF("i", "j")
     checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil)
   }
+
+  test("SPARK-19691 Calculating percentile of decimal column fails with 
ClassCastException") {
+    val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as 
x").selectExpr("percentile(x, 0.5)")
+    checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to