Repository: spark
Updated Branches:
  refs/heads/master 594ac4f7b -> 5264164a6


[SPARK-24648][SQL] SQLMetrics should be thread-safe

Use LongAdder to make SQLMetrics thread-safe.

## What changes were proposed in this pull request?
Replace `+=` on a plain `var` with `LongAdder.add()` so concurrent metric updates are counted correctly.

## How was this patch tested?
Unit tests that exercise a `SQLMetric` from a single thread and from multiple local threads.
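
For context (not part of the patch), a minimal sketch of the concurrency issue this change addresses: incrementing a plain `var` with `+=` from several threads is a non-atomic read-modify-write and can drop updates, whereas `LongAdder.add()` is safe under contention and `sum()` reads the accumulated total. The object and variable names below are illustrative only.

```scala
import java.util.concurrent.atomic.LongAdder

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object LongAdderSketch {
  def main(args: Array[String]): Unit = {
    implicit val ec: ExecutionContext = ExecutionContext.global

    // A plain var updated with += is a non-atomic read-modify-write,
    // so concurrent increments can be lost.
    var racyCounter = 0L

    // LongAdder.add() is safe under contention; sum() reads the total.
    val safeCounter = new LongAdder

    val updates = for (_ <- 1 to 1000) yield Future {
      for (_ <- 1 to 100) {
        racyCounter += 1    // may lose increments
        safeCounter.add(1)  // never loses increments
      }
    }
    Await.result(Future.sequence(updates), 1.minute)

    println(s"plain var:     $racyCounter")          // often less than 100000
    println(s"LongAdder sum: ${safeCounter.sum()}")  // always 100000
  }
}
```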

Author: Stacy Kerkela <stacy.kerk...@databricks.com>

Closes #21634 from dbkerkela/sqlmetrics-concurrency-stacy.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5264164a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5264164a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5264164a

Branch: refs/heads/master
Commit: 5264164a67df498b73facae207eda12ee133be7d
Parents: 594ac4f
Author: Stacy Kerkela <stacy.kerk...@databricks.com>
Authored: Mon Jun 25 23:41:39 2018 +0200
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Mon Jun 25 23:41:39 2018 +0200

----------------------------------------------------------------------
 .../spark/sql/execution/metric/SQLMetrics.scala | 33 +++++++++++-------
 .../sql/execution/metric/SQLMetricsSuite.scala  | 36 +++++++++++++++++++-
 2 files changed, 55 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5264164a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 77b9078..b4f0ae1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric
 
 import java.text.NumberFormat
 import java.util.Locale
+import java.util.concurrent.atomic.LongAdder
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.AccumulableInfo
@@ -32,40 +33,45 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
  * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]].
  */
 class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
+
   // This is a workaround for SPARK-11013.
   // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will
   // update it at the end of task and the value will be at least 0. Then we can filter out the -1
   // values before calculate max, min, etc.
-  private[this] var _value = initValue
-  private var _zeroValue = initValue
+  private[this] val _value = new LongAdder
+  private val _zeroValue = initValue
+  _value.add(initValue)
 
   override def copy(): SQLMetric = {
-    val newAcc = new SQLMetric(metricType, _value)
-    newAcc._zeroValue = initValue
+    val newAcc = new SQLMetric(metricType, initValue)
+    newAcc.add(_value.sum())
     newAcc
   }
 
-  override def reset(): Unit = _value = _zeroValue
+  override def reset(): Unit = this.set(_zeroValue)
 
   override def merge(other: AccumulatorV2[Long, Long]): Unit = other match {
-    case o: SQLMetric => _value += o.value
+    case o: SQLMetric => _value.add(o.value)
     case _ => throw new UnsupportedOperationException(
       s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
   }
 
-  override def isZero(): Boolean = _value == _zeroValue
+  override def isZero(): Boolean = _value.sum() == _zeroValue
 
-  override def add(v: Long): Unit = _value += v
+  override def add(v: Long): Unit = _value.add(v)
 
   // We can set a double value to `SQLMetric` which stores only long value, if it is
   // average metrics.
   def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v)
 
-  def set(v: Long): Unit = _value = v
+  def set(v: Long): Unit = {
+    _value.reset()
+    _value.add(v)
+  }
 
-  def +=(v: Long): Unit = _value += v
+  def +=(v: Long): Unit = _value.add(v)
 
-  override def value: Long = _value
+  override def value: Long = _value.sum()
 
   // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
   override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
@@ -153,7 +159,7 @@ object SQLMetrics {
           Seq.fill(3)(0L)
         } else {
           val sorted = validValues.sorted
-          Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+          Seq(sorted.head, sorted(validValues.length / 2), sorted(validValues.length - 1))
         }
         metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric))
       }
@@ -173,7 +179,8 @@ object SQLMetrics {
           Seq.fill(4)(0L)
         } else {
           val sorted = validValues.sorted
-          Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+          Seq(sorted.sum, sorted.head, sorted(validValues.length / 2),
+            sorted(validValues.length - 1))
         }
         metric.map(strFormat)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/5264164a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index a3a3f38..8263c9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.metric
 
 import java.io.File
 
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.execution.ui.SQLAppStatusStore
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -504,4 +504,38 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
   test("writing data out metrics with dynamic partition: parquet") {
     testMetricsDynamicPartition("parquet", "parquet", "t1")
   }
+
+  test("writing metrics from single thread") {
+    val nAdds = 10
+    val acc = new SQLMetric("test", -10)
+    assert(acc.isZero())
+    acc.set(0)
+    for (i <- 1 to nAdds) acc.add(1)
+    assert(!acc.isZero())
+    assert(nAdds === acc.value)
+    acc.reset()
+    assert(acc.isZero())
+  }
+
+  test("writing metrics from multiple threads") {
+    implicit val ec: ExecutionContextExecutor = ExecutionContext.global
+    val nFutures = 1000
+    val nAdds = 100
+    val acc = new SQLMetric("test", -10)
+    assert(acc.isZero() === true)
+    acc.set(0)
+    val l = for ( i <- 1 to nFutures ) yield {
+      Future {
+        for (j <- 1 to nAdds) acc.add(1)
+        i
+      }
+    }
+    for (futures <- Future.sequence(l)) {
+      assert(nFutures === futures.length)
+      assert(!acc.isZero())
+      assert(nFutures * nAdds === acc.value)
+      acc.reset()
+      assert(acc.isZero())
+    }
+  }
 }

