Repository: spark
Updated Branches:
  refs/heads/master 20adf9aa1 -> 365a29bdb
[SPARK-22100][SQL] Make percentile_approx support date/timestamp type and change the output type to be the same as input type

## What changes were proposed in this pull request?

The `percentile_approx` function previously accepted numeric type input and produced results of double type. Since numeric, date and timestamp types are all represented as numeric values internally, `percentile_approx` can support them easily. After this PR, it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. This change is also required when we generate equi-height histograms for these types.

## How was this patch tested?

Added a new test and modified some existing tests.

Author: Zhenhua Wang <wangzhen...@huawei.com>

Closes #19321 from wzhfy/approx_percentile_support_types.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/365a29bd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/365a29bd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/365a29bd

Branch: refs/heads/master
Commit: 365a29bdbfd18aae4b5374157dc1d2abfc64eb0e
Parents: 20adf9a
Author: Zhenhua Wang <wangzhen...@huawei.com>
Authored: Mon Sep 25 09:28:42 2017 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Mon Sep 25 09:28:42 2017 -0700

----------------------------------------------------------------------
 R/pkg/tests/fulltests/test_sparkSQL.R           |  4 +--
 docs/sql-programming-guide.md                   |  1 +
 python/pyspark/sql/dataframe.py                 | 10 +++---
 .../aggregate/ApproximatePercentile.scala       | 33 +++++++++++++++++---
 .../aggregate/ApproximatePercentileSuite.scala  |  6 ++--
 .../sql/ApproximatePercentileQuerySuite.scala   | 29 ++++++++++++++++-
 .../org/apache/spark/sql/DataFrameSuite.scala   |  6 ++--
 7 files changed, 70 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/R/pkg/tests/fulltests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 4d1010e..4e62be9 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -2538,14 +2538,14 @@ test_that("describe() and summary() on a DataFrame", {
 
   stats2 <- summary(df)
   expect_equal(collect(stats2)[5, "summary"], "25%")
-  expect_equal(collect(stats2)[5, "age"], "30.0")
+  expect_equal(collect(stats2)[5, "age"], "30")
 
   stats3 <- summary(df, "min", "max", "55.1%")
   expect_equal(collect(stats3)[1, "summary"], "min")
   expect_equal(collect(stats3)[2, "summary"], "max")
   expect_equal(collect(stats3)[3, "summary"], "55.1%")
-  expect_equal(collect(stats3)[3, "age"], "30.0")
+  expect_equal(collect(stats3)[3, "age"], "30")
 
   # SPARK-16425: SparkR summary() fails on column of type logical
   df <- withColumn(df, "boolean", df$age == 30)


http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 5db60cc..a095263 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1553,6 +1553,7 @@ options.
 
 ## Upgrading From Spark SQL 2.2 to 2.3
 
   - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
+  - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles.
 
 ## Upgrading From Spark SQL 2.1 to 2.2


http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7b81a0b..b7ce9a8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1038,9 +1038,9 @@ class DataFrame(object):
         |   mean|               3.5| null|
         | stddev|2.1213203435596424| null|
         |    min|                 2|Alice|
-        |    25%|               5.0| null|
-        |    50%|               5.0| null|
-        |    75%|               5.0| null|
+        |    25%|                 5| null|
+        |    50%|                 5| null|
+        |    75%|                 5| null|
         |    max|                 5|  Bob|
         +-------+------------------+-----+
@@ -1050,8 +1050,8 @@ class DataFrame(object):
         +-------+---+-----+
         |  count|  2|    2|
         |    min|  2|Alice|
-        |    25%|5.0| null|
-        |    75%|5.0| null|
+        |    25%|  5| null|
+        |    75%|  5| null|
         |    max|  5|  Bob|
         +-------+---+-----+


http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 896c009..7facb9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -85,7 +85,10 @@ case class ApproximatePercentile(
   private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
+    // Support NumericType, DateType and TimestampType since their internal types are all numeric,
+    // and can be easily cast to double for processing.
+    Seq(TypeCollection(NumericType, DateType, TimestampType),
+      TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
   }
 
   // Mark as lazy so that percentageExpression is not evaluated during tree transformation.
@@ -123,7 +126,15 @@ case class ApproximatePercentile(
     val value = child.eval(inputRow)
     // Ignore empty rows, for example: percentile_approx(null)
     if (value != null) {
-      buffer.add(value.asInstanceOf[Double])
+      // Convert the value to a double value
+      val doubleValue = child.dataType match {
+        case DateType => value.asInstanceOf[Int].toDouble
+        case TimestampType => value.asInstanceOf[Long].toDouble
+        case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
+        case other: DataType =>
+          throw new UnsupportedOperationException(s"Unexpected data type $other")
+      }
+      buffer.add(doubleValue)
     }
     buffer
   }
@@ -134,7 +145,20 @@ case class ApproximatePercentile(
   }
 
   override def eval(buffer: PercentileDigest): Any = {
-    val result = buffer.getPercentiles(percentages)
+    val doubleResult = buffer.getPercentiles(percentages)
+    val result = child.dataType match {
+      case DateType => doubleResult.map(_.toInt)
+      case TimestampType => doubleResult.map(_.toLong)
+      case ByteType => doubleResult.map(_.toByte)
+      case ShortType => doubleResult.map(_.toShort)
+      case IntegerType => doubleResult.map(_.toInt)
+      case LongType => doubleResult.map(_.toLong)
+      case FloatType => doubleResult.map(_.toFloat)
+      case DoubleType => doubleResult
+      case _: DecimalType => doubleResult.map(Decimal(_))
+      case other: DataType =>
+        throw new UnsupportedOperationException(s"Unexpected data type $other")
+    }
     if (result.length == 0) {
       null
     } else if (returnPercentileArray) {
@@ -155,8 +179,9 @@ case class ApproximatePercentile(
   // Returns null for empty inputs
   override def nullable: Boolean = true
 
+  // The result type is the same as the input type.
   override def dataType: DataType = {
-    if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType
+    if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType
   }
 
   override def prettyName: String = "percentile_approx"


http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
index fcb370a..84b3cc7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericInternalRow, Literal}
@@ -270,7 +270,6 @@ class ApproximatePercentileSuite extends SparkFunSuite {
         percentageExpression = percentageExpression,
         accuracyExpression = Literal(100))
 
-      val result = wrongPercentage.checkInputDataTypes()
       assert(
        wrongPercentage.checkInputDataTypes() match {
          case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true
@@ -281,7 +280,6 @@ class ApproximatePercentileSuite extends SparkFunSuite {
 
   test("class ApproximatePercentile, automatically add type casting for parameters") {
     val testRelation = LocalRelation('a.int)
-    val analyzer = SimpleAnalyzer
 
     // Compatible accuracy types: Long type and decimal type
     val accuracyExpressions = Seq(Literal(1000L), DecimalLiteral(10000), Literal(123.0D))
@@ -299,7 +297,7 @@ class ApproximatePercentileSuite extends SparkFunSuite {
       analyzed match {
         case Alias(agg: ApproximatePercentile, _) =>
           assert(agg.resolved)
-          assert(agg.child.dataType == DoubleType)
+          assert(agg.child.dataType == IntegerType)
           assert(agg.percentageExpression.dataType == DoubleType ||
             agg.percentageExpression.dataType == ArrayType(DoubleType, containsNull = false))
           assert(agg.accuracyExpression.dataType == IntegerType)


http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
index 62a7534..1aea337 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.sql.{Date, Timestamp}
+
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.test.SharedSQLContext
 
 /**
@@ -67,6 +70,30 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("percentile_approx, different column types") {
+    withTempView(table) {
+      val intSeq = 1 to 1000
+      val data: Seq[(java.math.BigDecimal, Date, Timestamp)] = intSeq.map { i =>
+        (new java.math.BigDecimal(i), DateTimeUtils.toJavaDate(i), DateTimeUtils.toJavaTimestamp(i))
+      }
+      data.toDF("cdecimal", "cdate", "ctimestamp").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             |  percentile_approx(cdecimal, array(0.25, 0.5, 0.75D)),
+             |  percentile_approx(cdate, array(0.25, 0.5, 0.75D)),
+             |  percentile_approx(ctimestamp, array(0.25, 0.5, 0.75D))
+             |FROM $table
+           """.stripMargin),
+        Row(
+          Seq("250.000000000000000000", "500.000000000000000000", "750.000000000000000000")
+            .map(i => new java.math.BigDecimal(i)),
+          Seq(250, 500, 750).map(DateTimeUtils.toJavaDate),
+          Seq(250, 500, 750).map(i => DateTimeUtils.toJavaTimestamp(i.toLong)))
+      )
+    }
+  }
+
   test("percentile_approx, multiple records with the minimum value in a partition") {
     withTempView(table) {
       spark.sparkContext.makeRDD(Seq(1, 1, 2, 1, 1, 3, 1, 1, 4, 1, 1, 5), 4).toDF("col")
@@ -88,7 +115,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
       val accuracies = Array(1, 10, 100, 1000, 10000)
       val errors = accuracies.map { accuracy =>
         val df = spark.sql(s"SELECT percentile_approx(col, 0.25, $accuracy) FROM $table")
-        val approximatePercentile = df.collect().head.getDouble(0)
+        val approximatePercentile = df.collect().head.getInt(0)
         val error = Math.abs(approximatePercentile - expectedPercentile)
         error
       }


http://git-wip-us.apache.org/repos/asf/spark/blob/365a29bd/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1334164..6178661 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -803,9 +803,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row("mean", null, "33.0", "178.0"),
       Row("stddev", null, "19.148542155126762", "11.547005383792516"),
       Row("min", "Alice", "16", "164"),
-      Row("25%", null, "24.0", "176.0"),
-      Row("50%", null, "24.0", "176.0"),
-      Row("75%", null, "32.0", "180.0"),
+      Row("25%", null, "24", "176"),
+      Row("50%", null, "24", "176"),
+      Row("75%", null, "32", "180"),
       Row("max", "David", "60", "192"))
 
     val emptySummaryResult = Seq(
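
As a usage-level illustration of the behavior change above, here is a minimal, self-contained sketch. It is not part of the commit: the table name `t`, the columns `cint`/`cdate`, and the wrapper object are invented for the example. With this change, `percentile_approx` returns a value of the input column's type (int and date below) instead of a double.

```scala
import java.sql.Date

import org.apache.spark.sql.SparkSession

object PercentileApproxTypesExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("percentile_approx result types")
      .getOrCreate()
    import spark.implicits._

    // An int column and a date column built from the same 1..100 range.
    val df = (1 to 100)
      .map(i => (i, Date.valueOf(f"2017-01-${i % 28 + 1}%02d")))
      .toDF("cint", "cdate")
    df.createOrReplaceTempView("t")

    val result = spark.sql(
      "SELECT percentile_approx(cint, 0.5), percentile_approx(cdate, 0.5) FROM t")
    // Before this change both result columns would have been doubles; now the
    // result schema mirrors the input types (int and date respectively).
    result.printSchema()
    result.show()

    spark.stop()
  }
}
```

Users who depended on the previous double results for numeric columns can restore them with an explicit cast, e.g. `CAST(percentile_approx(cint, 0.5) AS DOUBLE)`.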