Repository: spark
Updated Branches:
  refs/heads/branch-2.2 67c60d78e -> 397f90421


[SPARK-21597][SS] Fix a potential overflow issue in EventTimeStats

## What changes were proposed in this pull request?

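This PR fixes a potential overflow in `EventTimeStats`: instead of accumulating a `Long` sum of event times (which can overflow) and dividing it by the count, it now maintains the average incrementally as a `Double`.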

## How was this patch tested?

New unit tests in `EventTimeWatermarkSuite`.

Author: Shixiong Zhu <shixi...@databricks.com>

Closes #18803 from zsxwing/avg.

(cherry picked from commit 7f63e85b47a93434030482160e88fe63bf9cff4e)
Signed-off-by: Shixiong Zhu <shixi...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/397f9042
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/397f9042
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/397f9042

Branch: refs/heads/branch-2.2
Commit: 397f904219e7617386144aba87998a057bde02e3
Parents: 67c60d7
Author: Shixiong Zhu <shixi...@databricks.com>
Authored: Wed Aug 2 10:59:59 2017 -0700
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Wed Aug 2 11:00:09 2017 -0700

----------------------------------------------------------------------
 .../streaming/EventTimeWatermarkExec.scala      | 10 ++---
 .../execution/streaming/ProgressReporter.scala  |  2 +-
 .../sql/streaming/EventTimeWatermarkSuite.scala | 41 +++++++++++++++++++-
 3 files changed, 44 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/397f9042/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index 25cf609..55e7508 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -27,27 +27,25 @@ import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.AccumulatorV2
 
 /** Class for collecting event time stats with an accumulator */
-case class EventTimeStats(var max: Long, var min: Long, var sum: Long, var 
count: Long) {
+case class EventTimeStats(var max: Long, var min: Long, var avg: Double, var 
count: Long) {
   def add(eventTime: Long): Unit = {
     this.max = math.max(this.max, eventTime)
     this.min = math.min(this.min, eventTime)
-    this.sum += eventTime
     this.count += 1
+    this.avg += (eventTime - avg) / count
   }
 
   def merge(that: EventTimeStats): Unit = {
     this.max = math.max(this.max, that.max)
     this.min = math.min(this.min, that.min)
-    this.sum += that.sum
     this.count += that.count
+    this.avg += (that.avg - this.avg) * that.count / this.count
   }
-
-  def avg: Long = sum / count
 }
 
 object EventTimeStats {
   def zero: EventTimeStats = EventTimeStats(
-    max = Long.MinValue, min = Long.MaxValue, sum = 0L, count = 0L)
+    max = Long.MinValue, min = Long.MaxValue, avg = 0.0, count = 0L)
 }
 
 /** Accumulator that collects stats on event time in a batch. */

http://git-wip-us.apache.org/repos/asf/spark/blob/397f9042/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index a4e4ca8..db46fcd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -267,7 +267,7 @@ trait ProgressReporter extends Logging {
         Map(
           "max" -> stats.max,
           "min" -> stats.min,
-          "avg" -> stats.avg).mapValues(formatTimestamp)
+          "avg" -> stats.avg.toLong).mapValues(formatTimestamp)
     }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp
 
     ExecutionStats(numInputRows, stateOperators, eventTimeStats)

http://git-wip-us.apache.org/repos/asf/spark/blob/397f9042/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index 1b60a06..552911f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -21,7 +21,7 @@ import java.{util => ju}
 import java.text.SimpleDateFormat
 import java.util.Date
 
-import org.scalatest.BeforeAndAfter
+import org.scalatest.{BeforeAndAfter, Matchers}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.streaming.OutputMode._
 
-class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with 
Logging {
+class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with 
Matchers with Logging {
 
   import testImplicits._
 
@@ -38,6 +38,43 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin
     sqlContext.streams.active.foreach(_.stop())
   }
 
+  test("EventTimeStats") {
+    val epsilon = 10E-6
+
+    val stats = EventTimeStats(max = 100, min = 10, avg = 20.0, count = 5)
+    stats.add(80L)
+    stats.max should be (100)
+    stats.min should be (10)
+    stats.avg should be (30.0 +- epsilon)
+    stats.count should be (6)
+
+    val stats2 = EventTimeStats(80L, 5L, 15.0, 4)
+    stats.merge(stats2)
+    stats.max should be (100)
+    stats.min should be (5)
+    stats.avg should be (24.0 +- epsilon)
+    stats.count should be (10)
+  }
+
+  test("EventTimeStats: avg on large values") {
+    val epsilon = 10E-6
+    val largeValue = 10000000000L // 10B
+    // Make sure `largeValue` will cause overflow if we use a Long sum to calc 
avg.
+    assert(largeValue * largeValue != BigInt(largeValue) * BigInt(largeValue))
+    val stats =
+      EventTimeStats(max = largeValue, min = largeValue, avg = largeValue, 
count = largeValue - 1)
+    stats.add(largeValue)
+    stats.avg should be (largeValue.toDouble +- epsilon)
+
+    val stats2 = EventTimeStats(
+      max = largeValue + 1,
+      min = largeValue,
+      avg = largeValue + 1,
+      count = largeValue)
+    stats.merge(stats2)
+    stats.avg should be ((largeValue + 0.5) +- epsilon)
+  }
+
   test("error on bad column") {
     val inputData = MemoryStream[Int].toDF()
     val e = intercept[AnalysisException] {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to