This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6f8c620  [SPARK-35558] Optimizes for multi-quantile retrieval
6f8c620 is described below

commit 6f8c62047cea125d52af5dad7fb5ad3eadb7f7d0
Author: Alkis Polyzotis <alkis.polyzo...@databricks.com>
AuthorDate: Sat Jun 5 14:25:33 2021 -0500

    [SPARK-35558] Optimizes for multi-quantile retrieval
    
    ### What changes were proposed in this pull request?
    Optimizes the retrieval of approximate quantiles for an array of 
percentiles.
    * Adds an overload of QuantileSummaries.query that accepts a sequence of 
percentiles and computes all of them in a single pass over the sketch, 
avoiding redundant computation.
    * Modifies the ApproximatePercentile aggregate and StatFunctions to call the new method.
    
    All formatting changes are the result of running ./dev/scalafmt
    
    ### Why are the changes needed?
    The existing implementation issues a separate query for each input percentile, 
resulting in redundant computation.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests for the new method and updated DataFrameStatSuite to expect the new error message.
    
    Closes #32700 from alkispoly-db/spark_35558_approx_quants_array.
    
    Authored-by: Alkis Polyzotis <alkis.polyzo...@databricks.com>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 .../aggregate/ApproximatePercentile.scala          |  11 +--
 .../sql/catalyst/util/QuantileSummaries.scala      | 107 +++++++++++++++------
 .../sql/catalyst/util/QuantileSummariesSuite.scala |  79 +++++++++++----
 .../spark/sql/execution/stat/StatFunctions.scala   |   7 +-
 .../org/apache/spark/sql/DataFrameStatSuite.scala  |   2 +-
 5 files changed, 149 insertions(+), 57 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 38d8d7d..78e64bf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -261,19 +261,12 @@ object ApproximatePercentile {
      *   val Array(p25, median, p75) = 
percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
      * }}}
      */
-    def getPercentiles(percentages: Array[Double]): Array[Double] = {
+    def getPercentiles(percentages: Array[Double]): Seq[Double] = {
       if (!isCompressed) compress()
       if (summaries.count == 0 || percentages.length == 0) {
         Array.emptyDoubleArray
       } else {
-        val result = new Array[Double](percentages.length)
-        var i = 0
-        while (i < percentages.length) {
-          // Since summaries.count != 0, the query here never return None.
-          result(i) = summaries.query(percentages(i)).get
-          i += 1
-        }
-        result
+        summaries.query(percentages).get
       }
     }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
index addf140..e0cd613 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
@@ -229,46 +229,99 @@ class QuantileSummaries(
   }
 
   /**
-   * Runs a query for a given quantile.
+   * Finds the approximate quantile for a percentile, starting at a specific 
index in the summary.
+   * This is a helper method that is called as we are making a pass over the 
summary and a sorted
+   * sequence of input percentiles.
+   *
+   * @param index The point at which to start scanning the summary for an 
approximate value.
+   * @param minRankAtIndex The accumulated minimum rank at the given index.
+   * @param targetError Target error from the summary.
+   * @param percentile The percentile whose value is computed.
+   * @return A tuple (i, r, a) where: i is the updated index for the next 
call, r is the updated
+   *         rank at i, and a is the approximate quantile.
+   */
+  private def findApproxQuantile(
+      index: Int,
+      minRankAtIndex: Long,
+      targetError: Double,
+      percentile: Double): (Int, Long, Double) = {
+    var curSample = sampled(index)
+    val rank = math.ceil(percentile * count).toLong
+    var i = index
+    var minRank = minRankAtIndex
+    while (i < sampled.length - 1) {
+      val maxRank = minRank + curSample.delta
+      if (maxRank - targetError <= rank && rank <= minRank + targetError) {
+        return (i, minRank, curSample.value)
+      } else {
+        i += 1
+        curSample = sampled(i)
+        minRank += curSample.g
+      }
+    }
+    (sampled.length - 1, 0, sampled.last.value)
+  }
+
+  /**
+   * Runs a query for a given sequence of percentiles.
    * The result follows the approximation guarantees detailed above.
    * The query can only be run on a compressed summary: you need to call 
compress() before using
    * it.
    *
-   * @param quantile the target quantile
-   * @return
+   * @param percentiles the target percentiles
+   * @return the corresponding approximate quantiles, in the same order as the 
input
    */
-  def query(quantile: Double): Option[Double] = {
-    require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range 
[0.0, 1.0]")
-    require(headSampled.isEmpty,
+  def query(percentiles: Seq[Double]): Option[Seq[Double]] = {
+    percentiles.foreach(p =>
+      require(p >= 0 && p <= 1.0, "percentile should be in the range [0.0, 
1.0]"))
+    require(
+      headSampled.isEmpty,
       "Cannot operate on an uncompressed summary, call compress() first")
 
     if (sampled.isEmpty) return None
 
-    if (quantile <= relativeError) {
-      return Some(sampled.head.value)
-    }
+    val targetError = sampled.foldLeft(Long.MinValue)((currentMax, stats) =>
+      currentMax.max(stats.delta + stats.g)) / 2
 
-    if (quantile >= 1 - relativeError) {
-      return Some(sampled.last.value)
-    }
-
-    // Target rank
-    val rank = math.ceil(quantile * count).toLong
-    val targetError = sampled.map(s => s.delta + s.g).max / 2
+    // Index to track the current sample
+    var index = 0
     // Minimum rank at current sample
-    var minRank = 0L
-    var i = 0
-    while (i < sampled.length - 1) {
-      val curSample = sampled(i)
-      minRank += curSample.g
-      val maxRank = minRank + curSample.delta
-      if (maxRank - targetError <= rank && rank <= minRank + targetError) {
-        return Some(curSample.value)
-      }
-      i += 1
+    var minRank = sampled(0).g
+
+    val sortedPercentiles = percentiles.zipWithIndex.sortBy(_._1)
+    val result = Array.fill(percentiles.length)(0.0)
+    sortedPercentiles.foreach {
+      case (percentile, pos) =>
+        if (percentile <= relativeError) {
+          result(pos) = sampled.head.value
+        } else if (percentile >= 1 - relativeError) {
+          result(pos) = sampled.last.value
+        } else {
+          val (newIndex, newMinRank, approxQuantile) =
+            findApproxQuantile(index, minRank, targetError, percentile)
+          index = newIndex
+          minRank = newMinRank
+          result(pos) = approxQuantile
+        }
     }
-    Some(sampled.last.value)
+    Some(result)
   }
+
+  /**
+   * Runs a query for a given percentile.
+   * The result follows the approximation guarantees detailed above.
+   * The query can only be run on a compressed summary: you need to call 
compress() before using
+   * it.
+   *
+   * @param percentile the target percentile
+   * @return the corresponding approximate quantile
+   */
+  def query(percentile: Double): Option[Double] =
+    query(Seq(percentile)) match {
+      case Some(approxSeq) if approxSeq.nonEmpty => Some(approxSeq.head)
+      case _ => None
+    }
+
 }
 
 object QuantileSummaries {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
index e53d0bb..018db3aed 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
+import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
@@ -54,25 +55,51 @@ class QuantileSummariesSuite extends SparkFunSuite {
     summary
   }
 
-  private def checkQuantile(quant: Double, data: Seq[Double], summary: 
QuantileSummaries): Unit = {
+  private def validateQuantileApproximation(
+      approx: Double,
+      percentile: Double,
+      data: Seq[Double],
+      summary: QuantileSummaries): Unit = {
+    assert(data.nonEmpty)
+
+    val rankOfValue = data.count(_ <= approx)
+    val rankOfPreValue = data.count(_ < approx)
+    // `rankOfValue` is the last position of the quantile value. If the input 
repeats the value
+    // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, 
then it's
+    // improper to choose the last position as its rank. Instead, we get the 
rank by averaging
+    // `rankOfValue` and `rankOfPreValue`.
+    val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
+    val lower = math.floor((percentile - summary.relativeError) * data.size)
+    val upper = math.ceil((percentile + summary.relativeError) * data.size)
+    val msg =
+      s"$rank not in [$lower $upper], requested percentile: $percentile, 
approx returned: $approx"
+    assert(rank >= lower, msg)
+    assert(rank <= upper, msg)
+  }
+
+  private def checkQuantile(
+      percentile: Double,
+      data: Seq[Double],
+      summary: QuantileSummaries): Unit = {
     if (data.nonEmpty) {
-      val approx = summary.query(quant).get
-      // Get the rank of the approximation.
-      val rankOfValue = data.count(_ <= approx)
-      val rankOfPreValue = data.count(_ < approx)
-      // `rankOfValue` is the last position of the quantile value. If the 
input repeats the value
-      // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 
2, then it's
-      // improper to choose the last position as its rank. Instead, we get the 
rank by averaging
-      // `rankOfValue` and `rankOfPreValue`.
-      val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0)
-      val lower = math.floor((quant - summary.relativeError) * data.size)
-      val upper = math.ceil((quant + summary.relativeError) * data.size)
-      val msg =
-        s"$rank not in [$lower $upper], requested quantile: $quant, approx 
returned: $approx"
-      assert(rank >= lower, msg)
-      assert(rank <= upper, msg)
+      val approx = summary.query(percentile).get
+      validateQuantileApproximation(approx, percentile, data, summary)
+    } else {
+      assert(summary.query(percentile).isEmpty)
+    }
+  }
+
+  private def checkQuantiles(
+      percentiles: Seq[Double],
+      data: Seq[Double],
+      summary: QuantileSummaries): Unit = {
+    if (data.nonEmpty) {
+      val approx = summary.query(percentiles).get
+      for ((q, a) <- percentiles zip approx) {
+        validateQuantileApproximation(a, q, data, summary)
+      }
     } else {
-      assert(summary.query(quant).isEmpty)
+      assert(summary.query(percentiles).isEmpty)
     }
   }
 
@@ -98,6 +125,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
       checkQuantile(0.5, data, s)
       checkQuantile(0.1, data, s)
       checkQuantile(0.001, data, s)
+      checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+      checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
     }
 
     test(s"Some quantile values with epsi=$epsi and seq=$seq_name, 
compression=$compression " +
@@ -109,6 +138,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
       checkQuantile(0.5, data, s)
       checkQuantile(0.1, data, s)
       checkQuantile(0.001, data, s)
+      checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+      checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
     }
 
     test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, 
compression=$compression") {
@@ -121,6 +152,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
       checkQuantile(0.5, emptyData, s)
       checkQuantile(0.1, emptyData, s)
       checkQuantile(0.001, emptyData, s)
+      checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), emptyData, s)
+      checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), emptyData, s)
     }
   }
 
@@ -149,6 +182,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
       checkQuantile(0.5, data, s)
       checkQuantile(0.1, data, s)
       checkQuantile(0.001, data, s)
+      checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+      checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
     }
 
     val (data11, data12) = {
@@ -168,6 +203,8 @@ class QuantileSummariesSuite extends SparkFunSuite {
       checkQuantile(0.5, data, s)
       checkQuantile(0.1, data, s)
       checkQuantile(0.001, data, s)
+      checkQuantiles(Seq(0.001, 0.1, 0.5, 0.9, 0.9999), data, s)
+      checkQuantiles(Seq(0.9999, 0.9, 0.5, 0.1, 0.001), data, s)
     }
 
     // length of data21 is 4 * length of data22
@@ -181,10 +218,14 @@ class QuantileSummariesSuite extends SparkFunSuite {
       val s2 = buildSummary(data22, epsi, compression)
       val s = s1.merge(s2)
       // Check all quantiles
+      val percentiles = ArrayBuffer[Double]()
       for (queryRank <- 1 to n) {
-        val queryQuantile = queryRank.toDouble / n.toDouble
-        checkQuantile(queryQuantile, data, s)
+        val percentile = queryRank.toDouble / n.toDouble
+        checkQuantile(percentile, data, s)
+        percentiles += percentile
       }
+      checkQuantiles(percentiles.toSeq, data, s)
+      checkQuantiles(percentiles.reverse.toSeq, data, s)
     }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 0a9954e6..5dc0ff0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -102,7 +102,12 @@ object StatFunctions extends Logging {
     }
     val summaries = df.select(columns: 
_*).rdd.treeAggregate(emptySummaries)(apply, merge)
 
-    summaries.map { summary => probabilities.flatMap(summary.query) }
+    summaries.map {
+      summary => summary.query(probabilities) match {
+        case Some(q) => q
+        case None => Seq()
+      }
+    }
   }
 
   /** Calculate the Pearson Correlation Coefficient for the given columns */
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index cdd2568..79ab3cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -204,7 +204,7 @@ class DataFrameStatSuite extends QueryTest with 
SharedSparkSession {
     val e = intercept[IllegalArgumentException] {
       df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2, -0.1), 
epsilons.head)
     }
-    assert(e.getMessage.contains("quantile should be in the range [0.0, 1.0]"))
+    assert(e.getMessage.contains("percentile should be in the range [0.0, 
1.0]"))
 
     // relativeError should be non-negative
     val e2 = intercept[IllegalArgumentException] {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to