Repository: spark
Updated Branches:
  refs/heads/master 6960a7938 -> 73d92b00b


[SPARK-9018] [MLLIB] add stopwatches

Add stopwatches for easy instrumentation of MLlib algorithms. This is based on
the `TimeTracker` used in decision trees. The distributed version uses a Spark
accumulator. jkbradley
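
For a sense of the API, a minimal usage sketch (hypothetical: these classes
are private[spark], so this would run inside Spark/MLlib code, and the
stopwatch names here are placeholders):

    import org.apache.spark.ml.util.MultiStopwatch

    // assumes an existing SparkContext `sc`
    val timers = new MultiStopwatch(sc)
      .addLocal("init")         // driver-only timing
      .addDistributed("train")  // aggregates executor time via an accumulator

    timers("init").start()
    // ... driver-side setup work ...
    timers("init").stop()

    println(timers) // prints, e.g., "{\n  init: 12ms,\n  train: 0ms\n}"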

Author: Xiangrui Meng <m...@databricks.com>

Closes #7415 from mengxr/SPARK-9018 and squashes the following commits:

40b4347 [Xiangrui Meng] == -> ===
c477745 [Xiangrui Meng] address Joseph's comments
f981a49 [Xiangrui Meng] add stopwatches


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/73d92b00
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/73d92b00
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/73d92b00

Branch: refs/heads/master
Commit: 73d92b00b9a6f5dfc2f8116447d17b381cd74f80
Parents: 6960a79
Author: Xiangrui Meng <m...@databricks.com>
Authored: Wed Jul 15 21:02:42 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed Jul 15 21:02:42 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/util/stopwatches.scala  | 151 +++++++++++++++++++
 .../apache/spark/ml/util/StopwatchSuite.scala   | 109 +++++++++++++
 2 files changed, 260 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/73d92b00/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
new file mode 100644
index 0000000..5fdf878
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.mutable
+
+import org.apache.spark.{Accumulator, SparkContext}
+
+/**
+ * Abstract class for stopwatches.
+ */
+private[spark] abstract class Stopwatch extends Serializable {
+
+  @transient private var running: Boolean = false
+  private var startTime: Long = _
+
+  /**
+   * Name of the stopwatch.
+   */
+  val name: String
+
+  /**
+   * Starts the stopwatch.
+   * Throws an exception if the stopwatch is already running.
+   */
+  def start(): Unit = {
+    assume(!running, "start() called but the stopwatch is already running.")
+    running = true
+    startTime = now
+  }
+
+  /**
+   * Stops the stopwatch and returns the duration of the last session in milliseconds.
+   * Throws an exception if the stopwatch is not running.
+   */
+  def stop(): Long = {
+    assume(running, "stop() called but the stopwatch is not running.")
+    val duration = now - startTime
+    add(duration)
+    running = false
+    duration
+  }
+
+  /**
+   * Checks whether the stopwatch is running.
+   */
+  def isRunning: Boolean = running
+
+  /**
+   * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch
+   * is running.
+   */
+  def elapsed(): Long
+
+  /**
+   * Gets the current time in milliseconds.
+   */
+  protected def now: Long = System.currentTimeMillis()
+
+  /**
+   * Adds input duration to total elapsed time.
+   */
+  protected def add(duration: Long): Unit
+}
+
+/**
+ * A local [[Stopwatch]].
+ */
+private[spark] class LocalStopwatch(override val name: String) extends Stopwatch {
+
+  private var elapsedTime: Long = 0L
+
+  override def elapsed(): Long = elapsedTime
+
+  override protected def add(duration: Long): Unit = {
+    elapsedTime += duration
+  }
+}
+
+/**
+ * A distributed [[Stopwatch]] using a Spark accumulator.
+ * @param sc SparkContext
+ */
+private[spark] class DistributedStopwatch(
+    sc: SparkContext,
+    override val name: String) extends Stopwatch {
+
+  private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)")
+
+  override def elapsed(): Long = elapsedTime.value
+
+  override protected def add(duration: Long): Unit = {
+    elapsedTime += duration
+  }
+}
+
+/**
+ * A container for multiple stopwatches, both local and distributed.
+ * @param sc SparkContext
+ */
+private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable {
+
+  private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty
+
+  /**
+   * Adds a local stopwatch.
+   * @param name stopwatch name
+   */
+  def addLocal(name: String): this.type = {
+    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
+    stopwatches(name) = new LocalStopwatch(name)
+    this
+  }
+
+  /**
+   * Adds a distributed stopwatch.
+   * @param name stopwatch name
+   */
+  def addDistributed(name: String): this.type = {
+    require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
+    stopwatches(name) = new DistributedStopwatch(sc, name)
+    this
+  }
+
+  /**
+   * Gets a stopwatch.
+   * @param name stopwatch name
+   */
+  def apply(name: String): Stopwatch = stopwatches(name)
+
+  override def toString: String = {
+    stopwatches.values.toArray.sortBy(_.name)
+      .map(c => s"  ${c.name}: ${c.elapsed()}ms")
+      .mkString("{\n", ",\n", "\n}")
+  }
+}
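
A hedged sketch of the intended pattern, instrumenting a distributed step with
the classes above (the RDD and the per-partition work are placeholders):

    // assumes a SparkContext `sc`, inside private[spark] MLlib code
    val sw = new MultiStopwatch(sc)
      .addDistributed("map") // task time, merged back on the driver

    val n = sc.parallelize(0 until 1000, 4).mapPartitions { iter =>
      sw("map").start()
      val out = iter.map(i => i * i).toArray // toArray forces the lazy work
                                             // inside the timed region
      sw("map").stop()
      out.iterator
    }.count()

    // accumulator updates become visible on the driver once the action finishes
    println(s"map: ${sw("map").elapsed()}ms over $n records")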

http://git-wip-us.apache.org/repos/asf/spark/blob/73d92b00/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
new file mode 100644
index 0000000..8df6617
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
+    assert(sw.name === "sw")
+    assert(sw.elapsed() === 0L)
+    assert(!sw.isRunning)
+    intercept[AssertionError] {
+      sw.stop()
+    }
+    sw.start()
+    Thread.sleep(50)
+    val duration = sw.stop()
+    assert(duration >= 50 && duration < 100) // using a loose upper bound
+    val elapsed = sw.elapsed()
+    assert(elapsed === duration)
+    sw.start()
+    Thread.sleep(50)
+    val duration2 = sw.stop()
+    assert(duration2 >= 50 && duration2 < 100)
+    val elapsed2 = sw.elapsed()
+    assert(elapsed2 === duration + duration2)
+    sw.start()
+    assert(sw.isRunning)
+    intercept[AssertionError] {
+      sw.start()
+    }
+  }
+
+  test("LocalStopwatch") {
+    val sw = new LocalStopwatch("sw")
+    testStopwatchOnDriver(sw)
+  }
+
+  test("DistributedStopwatch on driver") {
+    val sw = new DistributedStopwatch(sc, "sw")
+    testStopwatchOnDriver(sw)
+  }
+
+  test("DistributedStopwatch on executors") {
+    val sw = new DistributedStopwatch(sc, "sw")
+    val rdd = sc.parallelize(0 until 4, 4)
+    rdd.foreach { i =>
+      sw.start()
+      Thread.sleep(50)
+      sw.stop()
+    }
+    assert(!sw.isRunning)
+    val elapsed = sw.elapsed()
+    assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound
+  }
+
+  test("MultiStopwatch") {
+    val sw = new MultiStopwatch(sc)
+      .addLocal("local")
+      .addDistributed("spark")
+    assert(sw("local").name === "local")
+    assert(sw("spark").name === "spark")
+    intercept[NoSuchElementException] {
+      sw("some")
+    }
+    assert(sw.toString === "{\n  local: 0ms,\n  spark: 0ms\n}")
+    sw("local").start()
+    sw("spark").start()
+    Thread.sleep(50)
+    sw("local").stop()
+    Thread.sleep(50)
+    sw("spark").stop()
+    val localElapsed = sw("local").elapsed()
+    val sparkElapsed = sw("spark").elapsed()
+    assert(localElapsed >= 50 && localElapsed < 100)
+    assert(sparkElapsed >= 100 && sparkElapsed < 200)
+    assert(sw.toString ===
+      s"{\n  local: ${localElapsed}ms,\n  spark: ${sparkElapsed}ms\n}")
+    val rdd = sc.parallelize(0 until 4, 4)
+    rdd.foreach { i =>
+      sw("local").start()
+      sw("spark").start()
+      Thread.sleep(50)
+      sw("spark").stop()
+      sw("local").stop()
+    }
+    val localElapsed2 = sw("local").elapsed()
+    assert(localElapsed2 === localElapsed)
+    val sparkElapsed2 = sw("spark").elapsed()
+    assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600)
+  }
+}
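
One detail the executor tests above rely on: `running` is @transient in
Stopwatch, so each task deserializes its own copy with running = false, and
start()/stop() inside a task never affects the driver-side flag. A small
sketch of the consequence (hypothetical name):

    val sw = new DistributedStopwatch(sc, "perTask")
    sc.parallelize(0 until 2, 2).foreach { _ =>
      sw.start() // flips `running` only in this task's deserialized copy
      sw.stop()  // adds this session's duration to the shared accumulator
    }
    assert(!sw.isRunning) // the driver's copy was never started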

