EBernhardson has uploaded a new change for review. ( https://gerrit.wikimedia.org/r/394741 )

Change subject: [WIP] Bad ideas for improved DBN performance
......................................................................

[WIP] Bad ideas for improved DBN performance

I'm not sure this is a particularly great idea, but I wanted to explore
the performance limits of the JVM-based DBN implementation. This brings
the original benchmark (90s in Java, 3-4s in the prior patch) to ~900ms.
To get a better idea of performance I increased the size of the benchmark:

* python: 616s
  - only ran once
* orig jvm: min: 21.7s, max: 24.1s, mean: 23.5s
  - 5 runs
  - 25x - 28x faster than python
* optimized jvm: min: 5.0s, max: 5.3s, mean: 5.2s
  - 5 runs
  - 116x - 123x faster than python
  - 4x - 5x faster than orig jvm

The improvements were guided by profiling in VisualVM and aren't
all that numerous:

* We were thrashing memory pretty hard at >1GB/sec. To reduce this,
  cache and reuse our intermediate arrays. We are still thrashing
  memory, but not as badly.

* Holding the caches of intermediate arrays in Scala Maps pushed
  those maps high up in the profiler. Replace them with arrays of
  queues. The backing linked list still shows up in profiling, but
  not as badly.

* DefaultMap.apply gets hit *a lot* and was showing up in profiling.
  Replacing the inner Scala maps with Java maps helped some; further
  replacing the Java maps with trove4j primitive maps helped
  significantly.

* Find places where we were repeatedly indexing into an array for the
  same item (for example, fetching something by s.queryId inside a
  loop over the urls) and pull it into a local var once. Not sure this
  made much difference.
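The size-bucketed buffer pool behind the first two bullets can be
sketched roughly like this (plain Java with made-up names, not the
patch's actual ArrayCache; the real one also resets reused buffers
to a caller-supplied default, as shown here):

```java
import java.util.ArrayDeque;
import java.util.Arrays;

// Minimal sketch of a size-bucketed pool for double[] buffers:
// one queue per array length, so hot-loop temporaries can be
// recycled instead of churning the allocator at >1GB/sec.
final class DoubleArrayPool {
    private static final int MAX_LEN = 20; // lengths above this are not pooled

    @SuppressWarnings("unchecked")
    private final ArrayDeque<double[]>[] queues = new ArrayDeque[MAX_LEN + 1];

    DoubleArrayPool() {
        for (int i = 0; i <= MAX_LEN; i++) {
            queues[i] = new ArrayDeque<>();
        }
    }

    // Fetch a buffer of length n, recycled if one is available,
    // always reset to the requested fill value.
    double[] get(int n, double fill) {
        double[] arr = (n <= MAX_LEN && !queues[n].isEmpty())
            ? queues[n].poll()
            : new double[n];
        Arrays.fill(arr, fill); // a reused buffer still holds old values
        return arr;
    }

    // Return a buffer to its size bucket; oversized buffers are dropped.
    void put(double[] arr) {
        if (arr.length <= MAX_LEN) {
            queues[arr.length].offer(arr);
        }
    }
}
```

The classic hazard of this scheme is returning a buffer while still
holding a live reference to it; single-threaded training makes that
manageable here, but it's the main reason this is filed under "bad
ideas".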

VisualVM now reports 80% of CPU time spent in our own functions,
whereas before it was significantly lower. Mostly I just kept looking
for places where supporting machinery was taking up CPU instead of
our calculations, and kept replacing them until things improved.
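The repeated-lookup hoisting from the last bullet is just pulling a
loop-invariant array access into a local; a hypothetical before/after
with made-up shapes (not the actual DBN types):

```java
// Illustration of hoisting an invariant array lookup out of a loop.
// The JIT can often do this itself, but an explicit local keeps the
// hot loop obvious in the profiler and costs nothing.
final class Hoisting {
    // Before: re-indexes relByQuery[queryId] on every iteration.
    static double sumSlow(double[][] relByQuery, int queryId, int[] urlIds) {
        double total = 0;
        for (int k = 0; k < urlIds.length; k++) {
            total += relByQuery[queryId][urlIds[k]];
        }
        return total;
    }

    // After: fetch the per-query row into a local once, outside the loop.
    static double sumHoisted(double[][] relByQuery, int queryId, int[] urlIds) {
        final double[] rel = relByQuery[queryId];
        double total = 0;
        for (int k = 0; k < urlIds.length; k++) {
            total += rel[urlIds[k]];
        }
        return total;
    }
}
```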

Change-Id: I08b72b98f515a820675e1ef9b45dd8724cbd070e
---
M jvm/pom.xml
M jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
M jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
3 files changed, 246 insertions(+), 58 deletions(-)


  git pull ssh://gerrit.wikimedia.org:29418/search/MjoLniR refs/changes/41/394741/1

diff --git a/jvm/pom.xml b/jvm/pom.xml
index b2a7f71..f405975 100644
--- a/jvm/pom.xml
+++ b/jvm/pom.xml
@@ -141,6 +141,11 @@
             <version>3.0.1</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>net.sf.trove4j</groupId>
+            <artifactId>trove4j</artifactId>
+            <version>3.0.3</version>
+        </dependency>
     </dependencies>
     <repositories>
         <repository>
diff --git a/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala 
b/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
index 12ef975..cda6778 100644
--- a/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
+++ b/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
@@ -9,6 +9,9 @@
   * A Dynamic Bayesian Network Click Model for Web Search Ranking - Olivier 
Chapelle and
  * Ya Zhang - http://olivier.chapelle.cc/pub/DBN_www2009.pdf
   */
+import gnu.trove.iterator.TIntObjectIterator
+import gnu.trove.map.hash.TIntObjectHashMap
+
 import scala.collection.mutable
 import scala.util.parsing.json.JSON
 
@@ -19,15 +22,22 @@
 
  // This bit maps input queries/results to array indexes to be used while 
calculating
   private var currentUrlId: Int = 0 // TODO: Why is first returned value 1 
instead of 0?
+  private val urlToIdMap: mutable.Map[String, Int] = mutable.Map()
+  def urlToId(key: String): Int = {
+    urlToIdMap.getOrElseUpdate(key, {
+      currentUrlId += 1
+      currentUrlId
+    })
+  }
+
   private var currentQueryId: Int = -1
-  private val urlToId: DefaultMap[String, Int] = new DefaultMap({ _ =>
-    currentUrlId += 1
-    currentUrlId
-  })
-  private val queryToId: DefaultMap[(String, String), Int] = new DefaultMap({ 
_ =>
-    currentQueryId += 1
-    currentQueryId
-  })
+  private val queryToIdMap: mutable.Map[(String, String), Int] = mutable.Map()
+  def queryToId(key: (String, String)): Int = {
+    queryToIdMap.getOrElseUpdate(key, {
+      currentQueryId += 1
+      currentQueryId
+    })
+  }
 
   def maxQueryId: Int = currentQueryId + 2
 
@@ -91,8 +101,8 @@
   }
 
   def toRelevances(urlRelevances: Array[Map[Int, UrlRel]]): 
Seq[RelevanceResult] = {
-    val idToUrl = urlToId.asMap.map(_.swap)
-    val idToQuery = queryToId.asMap.map(_.swap)
+    val idToUrl = urlToIdMap.map(_.swap)
+    val idToQuery = queryToIdMap.map(_.swap)
 
     urlRelevances.zipWithIndex.flatMap { case (d, queryId) =>
       val (query, region) = idToQuery(queryId)
@@ -101,6 +111,127 @@
         RelevanceResult(query, region, url, urlRel.a * urlRel.s)
       }
     }
+  }
+}
+
+class ArrayCache {
+  val QUEUE_1D_MAX = 20
+  private val queueMap1d: Array[mutable.Queue[Array[Double]]] = 
Array.fill(QUEUE_1D_MAX + 1){ mutable.Queue() }
+
+  def get1d(n: Int): Array[Double] = {
+    if (n > QUEUE_1D_MAX) {
+      new Array[Double](n)
+    } else {
+      val queue = queueMap1d(n)
+      if (queue.isEmpty) {
+        new Array[Double](n)
+      } else {
+        queue.dequeue()
+      }
+    }
+  }
+
+  def get1d(n: Int, default: Double): Array[Double] = {
+    if (n > QUEUE_1D_MAX) {
+      Array.fill(n)(default)
+    } else {
+      val queue = queueMap1d(n)
+      if (queue.isEmpty) {
+        Array.fill(n)(default)
+      } else {
+        val arr = queue.dequeue()
+        var i = 0
+        while (i < n) {
+          arr(i) = default
+          i += 1
+        }
+        arr
+      }
+    }
+  }
+
+  def put1d(arr: Array[Double]): Unit = {
+    if (arr.length <= QUEUE_1D_MAX) {
+      queueMap1d(arr.length) += arr
+    }
+  }
+
+  private val ALPHA_BETA_MAX = 21
+  private val alphaBetaMap: Array[mutable.Queue[Array[Array[Double]]]] = 
Array.fill(ALPHA_BETA_MAX + 1){
+    mutable.Queue()
+  }
+
+  def getAlphaBeta(n: Int): Array[Array[Double]] = {
+    if (n > ALPHA_BETA_MAX) {
+      Array.ofDim(n, 2)
+    } else {
+      val queue = alphaBetaMap(n)
+      if (queue.isEmpty) {
+        println(s"Allocate alpha/beta (n=$n)")
+        Array.ofDim(n, 2)
+      } else {
+        queue.dequeue()
+      }
+    }
+  }
+
+  def putAlphaBeta(arr: Array[Array[Double]]): Unit = {
+    if (arr.length <= ALPHA_BETA_MAX) {
+      alphaBetaMap(arr.length) += arr
+    }
+  }
+
+  private val POSITION_REL_MAX = 20
+  private val positionRelMap: Array[mutable.Queue[PositionRel]] = 
Array.fill(POSITION_REL_MAX + 1){ mutable.Queue() }
+
+  def getPositionRel(n: Int): PositionRel = {
+    if (n > POSITION_REL_MAX) {
+      new PositionRel(new Array[Double](n), new Array[Double](n))
+    } else {
+      val queue = positionRelMap(n)
+      if (queue.isEmpty) {
+        println(s"Allocate position rel (n=$n)")
+        new PositionRel(new Array[Double](n), new Array[Double](n))
+      } else {
+        queue.dequeue()
+      }
+    }
+  }
+
+  def putPositionRel(rel: PositionRel): Unit = {
+    if (rel.a.length <= POSITION_REL_MAX) {
+      positionRelMap(rel.a.length) += rel
+    }
+  }
+
+  private val UPDATE_MATRIX_MAX = 20
+  private val updateMatrixMap: 
Array[mutable.Queue[Array[Array[Array[Double]]]]] = 
Array.fill(UPDATE_MATRIX_MAX + 1){ mutable.Queue() }
+
+  def getUpateMatrix(n: Int): Array[Array[Array[Double]]] = {
+    if (n > UPDATE_MATRIX_MAX) {
+      Array.ofDim(n, 2, 2)
+    } else {
+      val queue = updateMatrixMap(n)
+      if (queue.isEmpty) {
+        println(s"Allocate update matrix (n=$n)")
+        Array.ofDim(n, 2, 2)
+      } else {
+        queue.dequeue()
+      }
+    }
+  }
+
+  def putUpdateMatrix(arr: Array[Array[Array[Double]]]): Unit = {
+    if (arr.length <= UPDATE_MATRIX_MAX) {
+      updateMatrixMap(arr.length) += arr
+    }
+  }
+
+  def clear(): Unit = {
+    queueMap1d.foreach(_.clear)
+    alphaBetaMap.foreach(_.clear)
+    positionRelMap.foreach(_.clear)
+    updateMatrixMap.foreach(_.clear)
   }
 }
 
@@ -155,26 +286,36 @@
 // it so requesting an item not in the map gets set to a default
 // value and then returned. This differs from withDefault which
 // expects to return an immutable value so doesn't set it into the map.
-class DefaultMap[K, V](default: K => V) extends Iterable[(K, V)] {
-  private val map = mutable.Map[K,V]()
+class DefaultMap[V](default: => V) {
+  private val map = new TIntObjectHashMap[V]()
 
-  def apply(key: K): V = {
-    map.get(key) match {
-      case Some(value) => value
-      case None =>
-        val value = default(key)
-        map.update(key, value)
-        value
+  def apply(key: Int): V = {
+    if (map.containsKey(key)) {
+      map.get(key)
+    } else {
+      val value: V = default
+      map.put(key, value)
+      value
     }
   }
 
-  override def iterator: Iterator[(K, V)] = map.iterator
-
   // converts to immutable scala Map
-  def asMap: Map[K, V] = map.toMap
+  def asMap: Map[Int, V] = {
+    val x = mutable.Map[Int, V]()
+    val iter = map.iterator()
+    while (iter.hasNext) {
+      iter.advance()
+      x.put(iter.key(), iter.value())
+    }
+    x.toMap
+  }
+
+  def iterator(): TIntObjectIterator[V] = map.iterator()
 }
 
 object DbnModel {
+  val doubleArrCache = new ArrayCache()
+
   /**
    * The forward-backward algorithm is used to compute the posterior 
probabilities of the hidden variables.
     *
@@ -241,26 +382,30 @@
   def getForwardBackwardEstimates(rel: PositionRel, gamma: Double, clicks: 
Array[Boolean]): (Array[Array[Double]], Array[Array[Double]]) = {
     val N = clicks.length
     // alpha(k)(e) = P(C_1,...C_k-1,E_i=e|a_u, s_u, G) calculated forwards for 
C_1, then C_1,C_2, ...
-    val alpha = Array.ofDim[Double](N + 1, 2)
+    val alpha = doubleArrCache.getAlphaBeta(N + 1)
     // beta(k)(e) = P(C_k,...C_N|E_i=e, a_u, s_u, G) calculated backwards for 
C_10, then C_9, C_10, ...
-    val beta = Array.ofDim[Double](N + 1, 2)
+    val beta = doubleArrCache.getAlphaBeta(N + 1)
 
+    alpha(0)(0) = 0D
     alpha(0)(1) = 1D
     beta(N)(0) = 1D
     beta(N)(1) = 1D
 
     // Forwards (alpha) and backwards (beta) need the same probabilities as 
inputs so pre-calculate them.
     var k = 0
-    val updateMatrix: Array[Array[Array[Double]]] = 
Array.ofDim[Double](clicks.length, 2, 2)
+    val updateMatrix = doubleArrCache.getUpateMatrix(clicks.length)
     while (k < N) {
       val a_u = rel.a(k)
       val s_u = rel.s(k)
       if (clicks(k)) {
+        updateMatrix(k)(0)(0) = 0D
         updateMatrix(k)(0)(1) = (s_u + (1 - gamma) * (1 - s_u)) * a_u
+        updateMatrix(k)(1)(0) = 0D
         updateMatrix(k)(1)(1) = gamma * (1 - s_u) * a_u
       } else {
         updateMatrix(k)(0)(0) = 1D
         updateMatrix(k)(0)(1) = (1D - gamma) * (1D - a_u)
+        updateMatrix(k)(1)(0) = 0D
         updateMatrix(k)(1)(1) = gamma * (1D - a_u)
       }
       k += 1
@@ -286,6 +431,8 @@
       k += 1
     }
 
+    doubleArrCache.putUpdateMatrix(updateMatrix)
+
     (alpha, beta)
   }
 
@@ -301,7 +448,7 @@
     // varphi is the smoothing of the forwards and backwards. I think, based 
on wiki page on forward/backwards
     // algorithm, that varphi is then P(E_k|C_1,...,C_N,a_u,s_u,G) but not 
100% sure...
     var k = 0
-    val varphi: Array[Double] = new Array(alpha.length)
+    val varphi: Array[Double] = doubleArrCache.get1d(alpha.length)
     while (k < alpha.length) {
       val a = alpha(k)
       val b = beta(k)
@@ -311,7 +458,10 @@
       k += 1
     }
 
-    val sessionEstimate = new PositionRel(new Array[Double](N), new 
Array[Double](N))
+    doubleArrCache.putAlphaBeta(alpha)
+    doubleArrCache.putAlphaBeta(beta)
+
+    val sessionEstimate = doubleArrCache.getPositionRel(N)
     k = 0
     while (k < N) {
       val a_u = rel.a(k)
@@ -331,6 +481,9 @@
       }
       k += 1
     }
+
+    doubleArrCache.put1d(varphi)
+
     sessionEstimate
   }
 }
@@ -342,59 +495,89 @@
     // dimension and urlId in the second dimension. Because queries only 
reference
     // a subset of the known urls we use a map at the second level instead of
     // creating the entire matrix.
-    val urlRelevances: Array[DefaultMap[Int, UrlRel]] = 
Array.fill(config.maxQueryId) {
-      new DefaultMap[Int, UrlRel]({
-        _ => new UrlRel(config.defaultRel, config.defaultRel)
+    val urlRelevances: Array[DefaultMap[UrlRel]] = 
Array.fill(config.maxQueryId) {
+      new DefaultMap[UrlRel]({
+        new UrlRel(config.defaultRel, config.defaultRel)
       })
     }
 
     for (_ <- 0 until config.maxIterations) {
-      for ((d, queryId) <- eStep(urlRelevances, sessions).view.zipWithIndex) {
+      val urlRelFractions = eStep(urlRelevances, sessions)
+      var queryId = 0
+      while (queryId < urlRelFractions.length) {
+        val d = urlRelFractions(queryId)
         // M step
-        for ((urlId, relFractions) <- d) {
-          val rel = urlRelevances(queryId)(urlId)
+        val queryUrlRelevances = urlRelevances(queryId)
+        val iter = d.iterator()
+        while (iter.hasNext) {
+          iter.advance()
+          val urlId = iter.key()
+          val relFractions = iter.value()
+          val rel = queryUrlRelevances(urlId)
           // Convert our sums of per-session a_u and s_u into probabilities 
(domain of [0,1])
           // attracted / (attracted + not-attracted)
           rel.a = relFractions.a(1) / (relFractions.a(1) + relFractions.a(0))
           // satisfied / (satisfied + not-satisfied)
           rel.s = relFractions.s(1) / (relFractions.s(1) + relFractions.s(0))
+
+          DbnModel.doubleArrCache.put1d(relFractions.a)
+          DbnModel.doubleArrCache.put1d(relFractions.s)
         }
+        queryId += 1
       }
     }
 
+    DbnModel.doubleArrCache.clear()
     urlRelevances.map(_.asMap)
   }
 
   // E step
-  private def eStep(urlRelevances: Array[DefaultMap[Int, UrlRel]], sessions: 
Seq[SessionItem])
-  : Array[DefaultMap[Int, UrlRelFrac]] = {
+  private def eStep(urlRelevances: Array[DefaultMap[UrlRel]], sessions: 
Seq[SessionItem])
+  : Array[DefaultMap[UrlRelFrac]] = {
     // urlRelFraction(queryId)(urlId)
-    val urlRelFractions: Array[DefaultMap[Int, UrlRelFrac]] = 
Array.fill(config.maxQueryId) {
-      new DefaultMap[Int, UrlRelFrac]({
-        _ => new UrlRelFrac(Array.fill(2)(1D), Array.fill(2)(1D))
+    val urlRelFractions: Array[DefaultMap[UrlRelFrac]] = 
Array.fill(config.maxQueryId) {
+      new DefaultMap[UrlRelFrac]({
+        new UrlRelFrac(
+          DbnModel.doubleArrCache.get1d(2, 1D),
+          DbnModel.doubleArrCache.get1d(2, 1D)
+        )
       })
     }
 
-    for (s <- sessions) {
-      val positionRelevances = new PositionRel(
-        s.urlIds.map(urlRelevances(s.queryId)(_).a),
-        s.urlIds.map(urlRelevances(s.queryId)(_).s)
-      )
+    var sidx = 0
+    while (sidx < sessions.length) {
+      val s = sessions(sidx)
+      val positionRelevances = 
DbnModel.doubleArrCache.getPositionRel(s.urlIds.length)
+      var i = 0
+      val urlRelQuery = urlRelevances(s.queryId)
+      while (i < s.urlIds.length) {
+        val urlRel = urlRelQuery(s.urlIds(i))
+        positionRelevances.a(i) = urlRel.a
+        positionRelevances.s(i) = urlRel.s
+        i += 1
+      }
 
       val sessionEstimate = DbnModel.getSessionEstimate(positionRelevances, 
gamma, s.clicks)
-      for ((urlId, k) <- s.urlIds.view.zipWithIndex) {
+      DbnModel.doubleArrCache.putPositionRel(positionRelevances)
+      val queryUrlRelFrac = urlRelFractions(s.queryId)
+      var k = 0
+      while (k < s.urlIds.length) {
+        var urlId = s.urlIds(k)
         // update attraction
-        val rel = urlRelFractions(s.queryId)(urlId)
+        val rel = queryUrlRelFrac(urlId)
         val estA = sessionEstimate.a(k)
-        rel.a(0) += (1 - estA)
+        rel.a(0) += 1 - estA
         rel.a(1) += estA
         if (s.clicks(k)) {
           // update satisfaction
           val estS = sessionEstimate.s(k)
-          rel.s(0) += (1 - estS)
+          rel.s(0) += 1 - estS
           rel.s(1) += estS
         }
+        k += 1
       }
+      DbnModel.doubleArrCache.putPositionRel(sessionEstimate)
+      sidx += 1
     }
     urlRelFractions
   }
diff --git a/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala 
b/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
index 9d883a2..efaf8d8 100644
--- a/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
+++ b/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
@@ -129,17 +129,17 @@
     println(s"Took ${took/1000000}ms")
 
     // Create a datafile that python clickmodels can read in to have fair 
comparison
-    //import java.io.File
-    //import java.io.PrintWriter
+    import java.io.File
+    import java.io.PrintWriter
 
-    //val writer = new PrintWriter(new File("/tmp/dbn.clickmodels"))
-    //for ( s <- sessions) {
-    //  // poor mans json serialization
-    //  val layout = Array.fill(s.urlIds.length)("false").mkString("[", ",", 
"]")
-    //  val clicks = s.clicks.map(_.toString).mkString("[", ",", "]")
-    //  val urls = s.urlIds.map(_.toString).mkString("[\"", "\",\"", "\"]")
-    //  writer.write(s"0\t${s.queryId}\tregion\t0\t$urls\t$layout\t$clicks\n")
-    //}
-    //writer.close()
+    val writer = new PrintWriter(new File("/tmp/dbn.clickmodels"))
+    for ( s <- sessions) {
+      // poor mans json serialization
+      val layout = Array.fill(s.urlIds.length)("false").mkString("[", ",", "]")
+      val clicks = s.clicks.map(_.toString).mkString("[", ",", "]")
+      val urls = s.urlIds.map(_.toString).mkString("[\"", "\",\"", "\"]")
+      writer.write(s"0\t${s.queryId}\tregion\t0\t$urls\t$layout\t$clicks\n")
+    }
+    writer.close()
   }
 }

-- 
To view, visit https://gerrit.wikimedia.org/r/394741
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: newchange
Gerrit-Change-Id: I08b72b98f515a820675e1ef9b45dd8724cbd070e
Gerrit-PatchSet: 1
Gerrit-Project: search/MjoLniR
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <ebernhard...@wikimedia.org>
