jenkins-bot has submitted this change and it was merged. ( https://gerrit.wikimedia.org/r/395592 )

Change subject: Add background resource monitor task to training executors
......................................................................


Add background resource monitor task to training executors

We have executors getting killed for overrunning their memory
allocations, but no clue why that is happening. Training an entire 35M
observation set on a single JVM (local Spark mode) needs only 10GB for
both heap and overhead (total RSS as reported by Linux). Creating a
DMatrix and running training with the 35M observations increases memory
use, as reported by the OS, by only ~3GB. Yet training that same 35M
observation set in YARN, split between three executors with 9GB of
memory overhead each, usually works but sometimes YARN kills our
process. That is 27GB of overhead alone, versus 10GB in total for local
mode. Intuitively we should be able to get by with much less memory
overhead, and leave significantly more memory for other Hadoop
processes to use.

Add a thread on executors that perform training to regularly report
both heap usage and the Rss lines from /proc/$pid/status. While this
won't tell us exactly what is happening, it should at least give some
insight into how memory usage develops over time, up to the point that
YARN decides to kill our executors.
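For reference, a minimal sketch of pulling the residency fields out of
/proc/<pid>/status, assuming Linux. The `ProcStatus` object and its field
list are hypothetical illustrations, not part of this change. Note that the
`RssAnon`/`RssFile`/`RssShmem` lines only appear on kernels 4.5 and later,
while `VmRSS` (current total) and `VmHWM` (peak) are available on older
kernels too.

```scala
import scala.io.Source

// Hypothetical helper (not part of this change): extracts memory-residency
// fields from the "Key:<whitespace>value" lines of a Linux /proc/<pid>/status file.
object ProcStatus {
  // VmRSS = current resident set size, VmHWM = peak resident set size.
  // RssAnon/RssFile/RssShmem break VmRSS down and exist on Linux 4.5+ only.
  val RssKeys: Set[String] = Set("VmRSS", "VmHWM", "RssAnon", "RssFile", "RssShmem")

  // Splits each line on its first ':' and keeps only the keys listed above.
  def rssFields(lines: Iterator[String]): Map[String, String] =
    lines
      .map(_.split(":", 2))
      .collect { case Array(key, value) if RssKeys.contains(key.trim) => key.trim -> value.trim }
      .toMap

  // Reads the current process's status via /proc/self and always closes the file.
  def currentRss(): Map[String, String] = {
    val src = Source.fromFile("/proc/self/status")
    try rssFields(src.getLines()) finally src.close()
  }
}
```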

This is intentionally implemented in a "once per JVM" way, which is a
bit odd but provides us the most information. Basically, the first time
an executor performs training the thread is spun up, and that thread
keeps running after the current task completes, up until the executor
itself exits. Because it is a daemon thread, it does not prevent the
executor JVM from exiting.
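A minimal sketch of the once-per-JVM pattern, with a hypothetical
`MonitorHolder` object standing in for `XGBoost`: the initializer of a lazy
val on a Scala object runs at most once, on first access, even when several
threads race to read it (strictly, once per class loader that loads the
object).

```scala
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical stand-in for the XGBoost object: the lazy val below is
// initialized on first access and then reused for the lifetime of the JVM.
object MonitorHolder {
  // Counts how many times the initializer actually ran.
  val initCount = new AtomicInteger(0)

  lazy val monitor: Thread = {
    initCount.incrementAndGet()
    val t = new Thread(new Runnable {
      // Placeholder body: the real monitor would report memory usage here.
      override def run(): Unit =
        try Thread.sleep(Long.MaxValue) catch { case _: InterruptedException => () }
    })
    t.setDaemon(true) // a daemon thread does not keep the JVM alive
    t.start()
    t
  }
}
```

Touching `MonitorHolder.monitor` from any number of tasks starts exactly
one thread; later accesses just return the already-running instance.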

Change-Id: I71c121055ea94b997bc018da4fc0d4d86d63bf66
---
A jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ResourceMonitorThread.scala
M jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
2 files changed, 70 insertions(+), 1 deletion(-)

Approvals:
  jenkins-bot: Verified
  DCausse: Looks good to me, approved



diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ResourceMonitorThread.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ResourceMonitorThread.scala
new file mode 100644
index 0000000..22ed87c
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ResourceMonitorThread.scala
@@ -0,0 +1,54 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.lang.management.ManagementFactory
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.commons.logging.LogFactory
+
+import scala.concurrent.duration.Duration
+import scala.io.Source
+
+class ResourceMonitorThread(reportEvery: Duration) extends Thread {
+  super.setDaemon(true)
+
+  private val keepChecking = new AtomicBoolean(true)
+  private val pid = ManagementFactory.getRuntimeMXBean.getName.split('@')(0).toInt
+  private val memoryBean = ManagementFactory.getMemoryMXBean
+  private val logger = LogFactory.getLog(this.getClass)
+
+  override def run(): Unit = {
+    if (!logger.isInfoEnabled) {
+      return
+    }
+    while (keepChecking.get()) {
+      report()
+      Thread.sleep(reportEvery.toMillis)
+    }
+  }
+
+  def stopChecking(): Unit = keepChecking.set(false)
+
+  private def report(): Unit = {
+    val rss = Source.fromFile(s"/proc/$pid/status").getLines()
+      .filter(_.startsWith("Rss"))
+      .mkString(", ")
+    logger.info(rss)
+    logger.info(memoryBean.getHeapMemoryUsage)
+  }
+}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 2f64e15..ea18ff2 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -17,8 +17,10 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.ByteArrayInputStream
+import java.util.concurrent.TimeUnit
 
 import scala.collection.mutable
+import scala.concurrent.duration.Duration
 import scala.util.Random
 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
@@ -30,6 +32,7 @@
 import org.apache.spark.sql.Dataset
 import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
+
 
 object TrackerConf {
   def apply(): TrackerConf = TrackerConf(0L, "python")
@@ -51,6 +54,14 @@
 
 object XGBoost extends Serializable {
   private val logger = LogFactory.getLog("XGBoostSpark")
+
+  // By using a lazy val on an object (singleton) we ensure this is only performed
+  // once per-jvm. It is only initialized in that jvm if accessed.
+  private lazy val monitor: ResourceMonitorThread = {
+    val m = new ResourceMonitorThread(Duration(10, TimeUnit.SECONDS))
+    m.start()
+    m
+  }
 
   private def fromDenseToSparseLabeledPoints(
       denseLabeledPoints: Iterator[XGBLabeledPoint],
@@ -127,12 +138,16 @@
       } else {
         null
       }
+
+      // Yes it's odd to access this but not do anything. We are ensuring the lazily
+      // initialized resource monitor is setup before we enter training.
+      monitor
+
       rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
       Rabit.init(rabitEnv)
       val watches = Watches(params,
         fromDenseToSparseLabeledPoints(labeledPoints, missing),
         fromBaseMarginsToArray(baseMargins), cacheFileName)
-
       try {
         val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
             .map(_.toString.toInt).getOrElse(0)

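One thing worth noting about `report()` as merged: the `Source` returned by
`Source.fromFile` is never closed, so each report may hold a file descriptor
open until garbage collection reclaims it. Below is a hypothetical variant
(`ClosingResourceMonitor`, with a plain `log` function standing in for the
commons-logging logger and the status path passed in so it can be tested)
that closes the file in a `finally` block, also matches the `VmRSS` line for
kernels that predate the `Rss*` fields, and interrupts the sleep in
`stopChecking()` so the thread exits promptly. It is a sketch, not the code
that was merged.

```scala
import java.lang.management.ManagementFactory
import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.duration.Duration
import scala.io.Source

// Hypothetical variant of ResourceMonitorThread (not the merged code).
class ClosingResourceMonitor(reportEvery: Duration, statusPath: String, log: String => Unit)
    extends Thread {
  setDaemon(true) // never keeps the executor JVM alive

  private val keepChecking = new AtomicBoolean(true)
  private val memoryBean = ManagementFactory.getMemoryMXBean

  override def run(): Unit =
    try {
      while (keepChecking.get()) {
        report()
        Thread.sleep(reportEvery.toMillis)
      }
    } catch {
      case _: InterruptedException => () // stopChecking() woke the sleep early
    }

  // Stops the loop and interrupts any in-progress sleep.
  def stopChecking(): Unit = {
    keepChecking.set(false)
    interrupt()
  }

  // Reads the residency lines and always releases the file handle.
  def readRss(): String = {
    val src = Source.fromFile(statusPath)
    try src.getLines().filter(l => l.startsWith("Rss") || l.startsWith("VmRSS")).mkString(", ")
    finally src.close()
  }

  private def report(): Unit = {
    log(readRss())
    log(memoryBean.getHeapMemoryUsage.toString)
  }
}
```

On an executor, `statusPath` would be `s"/proc/$pid/status"` (or
`/proc/self/status`), as in the diff.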
-- 
To view, visit https://gerrit.wikimedia.org/r/395592

Gerrit-MessageType: merged
Gerrit-Change-Id: I71c121055ea94b997bc018da4fc0d4d86d63bf66
Gerrit-PatchSet: 4
Gerrit-Project: search/xgboost
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <ebernhard...@wikimedia.org>
Gerrit-Reviewer: DCausse <dcau...@wikimedia.org>
Gerrit-Reviewer: EBernhardson <ebernhard...@wikimedia.org>
Gerrit-Reviewer: jenkins-bot <>

_______________________________________________
MediaWiki-commits mailing list
MediaWiki-commits@lists.wikimedia.org
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits
