Repository: spark
Updated Branches:
  refs/heads/master 20adf9aa1 -> 365a29bdb


[SPARK-22100][SQL] Make percentile_approx support date/timestamp type and 
change the output type to be the same as input type

## What changes were proposed in this pull request?

The `percentile_approx` function previously accepted only numeric input and 
always returned a double-typed result.

Since date and timestamp values, like all numeric types, are represented 
internally as numbers, `percentile_approx` can easily support them as well.
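
For illustration only (this sketch is not part of the patch): Catalyst stores a `DateType` value as an `Int` (days since the epoch) and a `TimestampType` value as a `Long` (microseconds since the epoch), so both can be fed into the same double-based digest as the numeric types. A minimal sketch using Catalyst's `DateTimeUtils` helpers:

```scala
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.util.DateTimeUtils

// DateType is internally an Int: the number of days since 1970-01-01.
val dateAsInt: Int = DateTimeUtils.fromJavaDate(Date.valueOf("1970-04-11"))

// TimestampType is internally a Long: microseconds since the epoch (1000 ms -> 1000000 us).
val tsAsLong: Long = DateTimeUtils.fromJavaTimestamp(new Timestamp(1000L))

// Both can therefore be treated as doubles inside the percentile digest.
val asDoubles = Seq(dateAsInt.toDouble, tsAsLong.toDouble)
```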

After this PR, it accepts date, timestamp and numeric types as input. The 
result type is also changed to match the input type, which is more natural for 
percentiles.

This change is also required when we generate equi-height histograms for these 
types.
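
A small usage sketch of the new behavior (hypothetical table/column names, assuming a running `SparkSession` named `spark` with `spark.implicits._` imported):

```scala
import java.sql.Timestamp

val events = Seq((1, new Timestamp(1000L)), (2, new Timestamp(2000L)), (3, new Timestamp(3000L)))
  .toDF("id", "event_time")
events.createOrReplaceTempView("events")

val median = spark.sql("SELECT percentile_approx(event_time, 0.5) AS median FROM events")
median.printSchema()
// Before this patch: `median` was a double column.
// After this patch:  `median` has the same type as the input column, i.e. timestamp.
```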

## How was this patch tested?

Added a new test and modified some existing tests.

Author: Zhenhua Wang <wangzhen...@huawei.com>

Closes #19321 from wzhfy/approx_percentile_support_types.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/365a29bd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/365a29bd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/365a29bd

Branch: refs/heads/master
Commit: 365a29bdbfd18aae4b5374157dc1d2abfc64eb0e
Parents: 20adf9a
Author: Zhenhua Wang <wangzhen...@huawei.com>
Authored: Mon Sep 25 09:28:42 2017 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Mon Sep 25 09:28:42 2017 -0700

----------------------------------------------------------------------
 R/pkg/tests/fulltests/test_sparkSQL.R           |  4 +--
 docs/sql-programming-guide.md                   |  1 +
 python/pyspark/sql/dataframe.py                 | 10 +++---
 .../aggregate/ApproximatePercentile.scala       | 33 +++++++++++++++++---
 .../aggregate/ApproximatePercentileSuite.scala  |  6 ++--
 .../sql/ApproximatePercentileQuerySuite.scala   | 29 ++++++++++++++++-
 .../org/apache/spark/sql/DataFrameSuite.scala   |  6 ++--
 7 files changed, 70 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/R/pkg/tests/fulltests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 4d1010e..4e62be9 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -2538,14 +2538,14 @@ test_that("describe() and summary() on a DataFrame", {
 
   stats2 <- summary(df)
   expect_equal(collect(stats2)[5, "summary"], "25%")
-  expect_equal(collect(stats2)[5, "age"], "30.0")
+  expect_equal(collect(stats2)[5, "age"], "30")
 
   stats3 <- summary(df, "min", "max", "55.1%")
 
   expect_equal(collect(stats3)[1, "summary"], "min")
   expect_equal(collect(stats3)[2, "summary"], "max")
   expect_equal(collect(stats3)[3, "summary"], "55.1%")
-  expect_equal(collect(stats3)[3, "age"], "30.0")
+  expect_equal(collect(stats3)[3, "age"], "30")
 
   # SPARK-16425: SparkR summary() fails on column of type logical
   df <- withColumn(df, "boolean", df$age == 30)

http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 5db60cc..a095263 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1553,6 +1553,7 @@ options.
 ## Upgrading From Spark SQL 2.2 to 2.3
 
   - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
+  - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles.
 
 ## Upgrading From Spark SQL 2.1 to 2.2
 

http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7b81a0b..b7ce9a8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1038,9 +1038,9 @@ class DataFrame(object):
         |   mean|               3.5| null|
         | stddev|2.1213203435596424| null|
         |    min|                 2|Alice|
-        |    25%|               5.0| null|
-        |    50%|               5.0| null|
-        |    75%|               5.0| null|
+        |    25%|                 5| null|
+        |    50%|                 5| null|
+        |    75%|                 5| null|
         |    max|                 5|  Bob|
         +-------+------------------+-----+
 
@@ -1050,8 +1050,8 @@ class DataFrame(object):
         +-------+---+-----+
         |  count|  2|    2|
         |    min|  2|Alice|
-        |    25%|5.0| null|
-        |    75%|5.0| null|
+        |    25%|  5| null|
+        |    75%|  5| null|
         |    max|  5|  Bob|
         +-------+---+-----+
 

http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 896c009..7facb9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -85,7 +85,10 @@ case class ApproximatePercentile(
   private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
+    // Support NumericType, DateType and TimestampType since their internal types are all numeric,
+    // and can be easily cast to double for processing.
+    Seq(TypeCollection(NumericType, DateType, TimestampType),
+      TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
   }
 
   // Mark as lazy so that percentageExpression is not evaluated during tree transformation.
@@ -123,7 +126,15 @@ case class ApproximatePercentile(
     val value = child.eval(inputRow)
     // Ignore empty rows, for example: percentile_approx(null)
     if (value != null) {
-      buffer.add(value.asInstanceOf[Double])
+      // Convert the value to a double value
+      val doubleValue = child.dataType match {
+        case DateType => value.asInstanceOf[Int].toDouble
+        case TimestampType => value.asInstanceOf[Long].toDouble
+        case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
+        case other: DataType =>
+          throw new UnsupportedOperationException(s"Unexpected data type $other")
+      }
+      buffer.add(doubleValue)
     }
     buffer
   }
@@ -134,7 +145,20 @@ case class ApproximatePercentile(
   }
 
   override def eval(buffer: PercentileDigest): Any = {
-    val result = buffer.getPercentiles(percentages)
+    val doubleResult = buffer.getPercentiles(percentages)
+    val result = child.dataType match {
+      case DateType => doubleResult.map(_.toInt)
+      case TimestampType => doubleResult.map(_.toLong)
+      case ByteType => doubleResult.map(_.toByte)
+      case ShortType => doubleResult.map(_.toShort)
+      case IntegerType => doubleResult.map(_.toInt)
+      case LongType => doubleResult.map(_.toLong)
+      case FloatType => doubleResult.map(_.toFloat)
+      case DoubleType => doubleResult
+      case _: DecimalType => doubleResult.map(Decimal(_))
+      case other: DataType =>
+        throw new UnsupportedOperationException(s"Unexpected data type $other")
+    }
     if (result.length == 0) {
       null
     } else if (returnPercentileArray) {
@@ -155,8 +179,9 @@ case class ApproximatePercentile(
   // Returns null for empty inputs
   override def nullable: Boolean = true
 
+  // The result type is the same as the input type.
   override def dataType: DataType = {
-    if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType
+    if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType
   }
 
   override def prettyName: String = "percentile_approx"

http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
index fcb370a..84b3cc7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericInternalRow, Literal}
@@ -270,7 +270,6 @@ class ApproximatePercentileSuite extends SparkFunSuite {
         percentageExpression = percentageExpression,
         accuracyExpression = Literal(100))
 
-      val result = wrongPercentage.checkInputDataTypes()
       assert(
         wrongPercentage.checkInputDataTypes() match {
           case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true
@@ -281,7 +280,6 @@ class ApproximatePercentileSuite extends SparkFunSuite {
 
   test("class ApproximatePercentile, automatically add type casting for parameters") {
     val testRelation = LocalRelation('a.int)
-    val analyzer = SimpleAnalyzer
 
     // Compatible accuracy types: Long type and decimal type
     val accuracyExpressions = Seq(Literal(1000L), DecimalLiteral(10000), Literal(123.0D))
@@ -299,7 +297,7 @@ class ApproximatePercentileSuite extends SparkFunSuite {
         analyzed match {
           case Alias(agg: ApproximatePercentile, _) =>
             assert(agg.resolved)
-            assert(agg.child.dataType == DoubleType)
+            assert(agg.child.dataType == IntegerType)
             assert(agg.percentageExpression.dataType == DoubleType ||
              agg.percentageExpression.dataType == ArrayType(DoubleType, containsNull = false))
             assert(agg.accuracyExpression.dataType == IntegerType)

http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
index 62a7534..1aea337 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.sql.{Date, Timestamp}
+
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.test.SharedSQLContext
 
 /**
@@ -67,6 +70,30 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("percentile_approx, different column types") {
+    withTempView(table) {
+      val intSeq = 1 to 1000
+      val data: Seq[(java.math.BigDecimal, Date, Timestamp)] = intSeq.map { i =>
+        (new java.math.BigDecimal(i), DateTimeUtils.toJavaDate(i), DateTimeUtils.toJavaTimestamp(i))
+      }
+      data.toDF("cdecimal", "cdate", "ctimestamp").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             |  percentile_approx(cdecimal, array(0.25, 0.5, 0.75D)),
+             |  percentile_approx(cdate, array(0.25, 0.5, 0.75D)),
+             |  percentile_approx(ctimestamp, array(0.25, 0.5, 0.75D))
+             |FROM $table
+           """.stripMargin),
+        Row(
+          Seq("250.000000000000000000", "500.000000000000000000", "750.000000000000000000")
+              .map(i => new java.math.BigDecimal(i)),
+          Seq(250, 500, 750).map(DateTimeUtils.toJavaDate),
+          Seq(250, 500, 750).map(i => DateTimeUtils.toJavaTimestamp(i.toLong)))
+      )
+    }
+  }
+
   test("percentile_approx, multiple records with the minimum value in a partition") {
     withTempView(table) {
       spark.sparkContext.makeRDD(Seq(1, 1, 2, 1, 1, 3, 1, 1, 4, 1, 1, 5), 4).toDF("col")
@@ -88,7 +115,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
       val accuracies = Array(1, 10, 100, 1000, 10000)
       val errors = accuracies.map { accuracy =>
         val df = spark.sql(s"SELECT percentile_approx(col, 0.25, $accuracy) FROM $table")
-        val approximatePercentile = df.collect().head.getDouble(0)
+        val approximatePercentile = df.collect().head.getInt(0)
         val error = Math.abs(approximatePercentile - expectedPercentile)
         error
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1334164..6178661 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -803,9 +803,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row("mean", null, "33.0", "178.0"),
       Row("stddev", null, "19.148542155126762", "11.547005383792516"),
       Row("min", "Alice", "16", "164"),
-      Row("25%", null, "24.0", "176.0"),
-      Row("50%", null, "24.0", "176.0"),
-      Row("75%", null, "32.0", "180.0"),
+      Row("25%", null, "24", "176"),
+      Row("50%", null, "24", "176"),
+      Row("75%", null, "32", "180"),
       Row("max", "David", "60", "192"))
 
     val emptySummaryResult = Seq(

