jenkins-bot has submitted this change and it was merged. ( https://gerrit.wikimedia.org/r/394741 )

Change subject: Improve DBN performance another 10x
......................................................................


Improve DBN performance another 10x

I'm not sure this is a particularly great idea, but I wanted to explore
the performance limits of the JVM-based DBN implementation. This brings
the original benchmark (90s in java, 3-4s in prior patch) down to ~750ms.
To get a better idea of performance I increased the size of the benchmark:

* python: 616s
  - only ran once
* orig jvm: min: 21.7s, max: 24.1s, mean: 23.5s
  - 5 runs
  - 25x - 28x faster than python
* optimized jvm: min: 1.7s, max: 2.1s, mean: 1.8s
  - 5 runs
  - 293x - 362x faster than python
  - 10x - 14x faster than orig jvm
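
(For reference, the ratios above are just one runtime divided by the
other: 616s / 2.1s ~ 293x and 616s / 1.7s ~ 362x against python, and
21.7s / 2.1s ~ 10x and 24.1s / 1.7s ~ 14x against the orig jvm.)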

The improvements made were guided by profiling in visualvm and aren't
all that numerous:

* We were thrashing memory pretty hard, at >1GB/sec. To reduce this,
  pre-allocate our various intermediate arrays. I opted to keep the
  original control flow, returning those arrays instead of accessing
  the instance variables everywhere, for some level of clarity, but tbh
  this is probably still harder to follow than it was originally.

* DefaultMap.apply gets hit *a lot* and was showing up in profiling.
  Instead of using maps, assign urlIds per-query so they always start
  at 0, and generate arrays of appropriate size instead of maps with
  default values.

* Remove all Scala-style iteration in favor of while loops, as we are
  using arrays that can be accessed in order (see the sketch after
  this list). This combined with the above has completely removed all
  allocation while running the dbn; visualvm shows completely flat
  memory usage.

* Find places where we were repeatedly hitting an array for the same
  item (for example getting something by s.queryId in a loop over the
  urls) and fetch it into a local var. Not sure this made much
  difference. Limited this to mostly just urls, as it's relatively
  minor and saving 2 accesses is meh, but saving 20 accesses per url
  list might help a little.

* Cache some common calculations, like 1 - gamma, instead of repeating
  them. Probably makes very little difference, but it didn't seem like
  it could hurt.
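
As a rough sketch of the pre-allocate + while-loop pattern described
above (hypothetical illustration only, not the actual DBN internals;
Scorer/scores are made-up names):

  // Before: every call allocates a fresh result array plus the closure
  // and iterator machinery behind .map.
  def scoresPerCall(rels: Array[Double]): Array[Double] =
    rels.map { r => r * r }

  // After: one buffer allocated up front and reused across calls; the
  // while loop over indexes allocates nothing in steady state.
  class Scorer(maxSize: Int) {
    private val out = new Array[Double](maxSize)
    def scores(rels: Array[Double]): Array[Double] = {
      var i = 0
      while (i < rels.length) {
        out(i) = rels(i) * rels(i)
        i += 1
      }
      out // callers must only read the first rels.length entries
    }
  }

The reused buffer is also why the real code has to be careful to size
its loops by the input length rather than the buffer length.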

Also tried (without profiling suggesting they were a problem), with
little or no benefit; combined they reduce runtime to ~1.5s, or 1.25x
faster:

* Convert clicks from Array[Boolean] => scala.collection.BitSet (worse)
* Aggressively prevent re-accessing the same array indexes, instead
  creating local variables holding intermediates. (small benefit)
* Iterate in reverse order, so our conditions are always (x >= 0); see
  the sketch after this list. (maybe small benefit, mostly from not
  accessing .length every iteration)
* Build the various arrays explicitly with while loops, instead of
  x.map { .. }.toArray (very small benefit, and only on setup)
* More aggressively prevent duplicate math, such as N - 1 - k being
  done 4 times per position in calcForwardBackwardEstimates (very small
  benefit)
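
For reference, the reverse-iteration variant tried above looks roughly
like this (illustrative snippet only; xs/total are made-up names):

  val xs = Array.fill(20)(1.0)
  var total = 0.0
  // forward: the loop condition re-reads xs.length every iteration
  var i = 0
  while (i < xs.length) { total += xs(i); i += 1 }
  // reverse: the bound is read once and each test is a plain
  // comparison against zero
  var j = xs.length - 1
  while (j >= 0) { total += xs(j); j -= 1 }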

visualvm now reports 100% of cpu time is spent in our own functions,
whereas before it was significantly lower. Mostly I just kept looking
for places where the supporting machinery was taking up cpu instead
of our calculations and kept replacing them until things improved. The
cpu usage split is now basically:

* calcForwardBackwardEstimates: 50%
* eStep: 30%
* getSessionEstimate: 20%

Change-Id: I08b72b98f515a820675e1ef9b45dd8724cbd070e
---
M jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
M jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
2 files changed, 293 insertions(+), 170 deletions(-)

Approvals:
  jenkins-bot: Verified
  DCausse: Looks good to me, approved



diff --git a/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala b/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
index 7f50477..faac7dc 100644
--- a/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
+++ b/jvm/src/main/scala/org/wikimedia/search/mjolnir/DBN.scala
@@ -8,29 +8,41 @@
   *
   * A Dynamic Bayesian Network Click Model for Web Search Ranking - Olivier Chapelle and
   * Ya Zhang - http://olivier.chapelle.cc/pub/DBN_www2009.pdf
+  *
+  * It's worth noting that all of the math notes in this file are post-hoc. The
+  * implementation was ported from python clickmodels by Aleksandr Chuklin and the
+  * notes on math were added in an attempt to understand why the implementation works.
   */
 import scala.collection.mutable
 import org.json4s.{JArray, JBool, JString}
 import org.json4s.jackson.JsonMethods
 
-case class SessionItem(queryId: Int, urlIds: Array[Int], clicks: Array[Boolean])
-case class RelevanceResult(query: String, region: String, url: String, relevance: Double)
+class SessionItem(val queryId: Int, val urlIds: Array[Int], val clicks: Array[Boolean])
+class RelevanceResult(val query: String, val region: String, val url: String, val relevance: Double)
 
 class InputReader(minDocsPerQuery: Int, serpSize: Int, discardNoClicks: Boolean) {
 
   // This bit maps input queries/results to array indexes to be used while calculating
-  private var currentUrlId: Int = 0 // TODO: Why is first returned value 1 instead of 0?
-  private var currentQueryId: Int = -1
-  private val urlToId: DefaultMap[String, Int] = new DefaultMap({ _ =>
-    currentUrlId += 1
-    currentUrlId
-  })
-  private val queryToId: DefaultMap[(String, String), Int] = new DefaultMap({ _ =>
-    currentQueryId += 1
-    currentQueryId
-  })
+  private val queryIdToNextUrlId: mutable.Map[Int, Int] = mutable.Map()
+  private val queryIdToUrlToIdMap: mutable.Map[Int, mutable.Map[String, Int]] = mutable.Map()
 
-  def maxQueryId: Int = currentQueryId + 2
+  def urlToId(queryId: Int, url: String): Int = {
+    val urlToIdMap = queryIdToUrlToIdMap.getOrElseUpdate(queryId, { mutable.Map() })
+    urlToIdMap.getOrElseUpdate(url, {
+      var nextUrlId = queryIdToNextUrlId.getOrElse(queryId, 0)
+      queryIdToNextUrlId(queryId) = nextUrlId + 1
+      nextUrlId
+    })
+  }
+
+  private var nextQueryId: Int = 0
+  private val queryToIdMap: mutable.Map[(String, String), Int] = mutable.Map()
+  def queryToId(key: (String, String)): Int = {
+    queryToIdMap.getOrElseUpdate(key, {
+      nextQueryId += 1
+      nextQueryId - 1
+    })
+  }
 
   private def parseJsonBooleanArray(json: String): Array[Boolean] = {
     JsonMethods.parse(json) match {
@@ -58,15 +70,24 @@
 
   def makeSessionItem(query: String, region: String, urls: Array[String], clicks: Array[Boolean]): Option[SessionItem] = {
     val n = math.min(serpSize, urls.length)
-    val hasClicks = clicks.take(n).exists { x => x}
+    val allClicks: Array[Boolean] = if (clicks.length >= n) {
+      clicks.take(n)
+    } else {
+      // pad clicks up to n with false
+      val c: Array[Boolean] = Array.fill(n)(false)
+      clicks.zipWithIndex.foreach { case (clicked, i) => c(i) = clicked }
+      c
+    }
+
+    val hasClicks = allClicks.take(n).exists { x => x}
     if (urls.length < minDocsPerQuery ||
         (discardNoClicks && !hasClicks)
     ) {
       None
     } else {
       val queryId = queryToId((query, region))
-      val urlIds = urls.take(n).map { url => urlToId(url) }
-      Some(SessionItem(queryId, urlIds, clicks.take(n)))
+      val urlIds = urls.take(n).map { url => urlToId(queryId, url) }
+      Some(new SessionItem(queryId, urlIds, allClicks))
     }
   }
 
@@ -79,8 +100,9 @@
   val PIECE_CLICKS = 6
 
   // TODO: Ideally don't use this and make session items directly without extra ser/deser overhead
+  // This is primarily for compatibility with the input format of the python clickmodels library.
   def read(f: Iterator[String]): Seq[SessionItem] = {
-    f.flatMap { line =>
+    val sessions = f.flatMap { line =>
       val pieces = line.split("\t")
       val query: String = pieces(PIECE_QUERY)
       val region = pieces(PIECE_REGION)
@@ -89,23 +111,40 @@
 
       makeSessionItem(query, region, urls, clicks)
     }.toSeq
+    // Guarantee we return a materialized collection and not a lazy one
+    // which won't have properly updated our max query/url ids
+    sessions.last
+    sessions
   }
 
-  def toRelevances(urlRelevances: Array[Map[Int, UrlRel]]): Seq[RelevanceResult] = {
-    val idToUrl = urlToId.asMap.map(_.swap)
-    val idToQuery = queryToId.asMap.map(_.swap)
+  def toRelevances(urlRelevances: Array[Array[UrlRel]]): Seq[RelevanceResult] = {
+    val queryToUrlIdToUrl = queryIdToUrlToIdMap.map { case (queryId, urlToId) =>
+      (queryId, urlToId.map(_.swap))
+    }
+    val queryIdToQuery = queryToIdMap.map(_.swap)
 
     urlRelevances.zipWithIndex.flatMap { case (d, queryId) =>
-      val (query, region) = idToQuery(queryId)
-      d.map { case (urlId, urlRel) =>
-        val url = idToUrl(urlId)
-        RelevanceResult(query, region, url, urlRel.a * urlRel.s)
+      val (query, region) = queryIdToQuery(queryId)
+      val urlIdToUrl = queryToUrlIdToUrl(queryId)
+      d.zipWithIndex.view.map { case (urlRel, urlId) =>
+        val url = urlIdToUrl(urlId)
+        new RelevanceResult(query, region, url, urlRel.a * urlRel.s)
       }
     }
   }
+
+
+  def config(defaultRel: Double, maxIterations: Int): Config = {
+    val maxUrlIds: Array[Int] = (0 until nextQueryId).map { queryId =>
+      queryIdToNextUrlId(queryId) - 1
+    }.toArray
+    new Config(nextQueryId - 1, defaultRel, maxIterations, serpSize, maxUrlIds)
+  }
 }
 
-class Config(val maxQueryId: Int, val defaultRel: Double, val maxIterations: Int)
+class Config(val maxQueryId: Int, val defaultRel: Double, val maxIterations: Int, val serpSize: Int, val maxUrlIds: Array[Int]) {
+  val maxUrlId: Int = maxUrlIds.max
+}
 
 // Some definitions:
 //
@@ -152,30 +191,112 @@
   clicks: Array[Double])
 
 
-// Bit of a hack ... but to make things easy to deal with this makes
-// it so requesting an item not in the map gets set to a default
-// value and then returned. This differs from withDefault which
-// expects to return an immutable value so doesn't set it into the map.
-class DefaultMap[K, V](default: K => V) extends Iterable[(K, V)] {
-  private val map = mutable.Map[K,V]()
+class DbnModel(gamma: Double, config: Config) {
+  val invGamma: Double = 1D - gamma
 
-  def apply(key: K): V = {
-    map.get(key) match {
-      case Some(value) => value
-      case None =>
-        val value = default(key)
-        map.update(key, value)
-        value
+  def train(sessions: Seq[SessionItem]): Array[Array[UrlRel]] = {
+    // This is basically a multi-dimensional array with queryId in the first
+    // dimension and urlId in the second dimension. Because queries only reference
+    // a subset of the known urls the second level is an array sized to each
+    // query's url count rather than the full url matrix.
+    val urlRelevances: Array[Array[UrlRel]] = (0 to config.maxQueryId).map { queryId =>
+      (0 to config.maxUrlIds(queryId)).map { _ => new UrlRel(config.defaultRel, config.defaultRel) }.toArray
+    }.toArray
+
+    for (_ <- 0 until config.maxIterations) {
+      val urlRelFractions = eStep(urlRelevances, sessions)
+      var queryId = config.maxQueryId
+      while (queryId >= 0) {
+        // M step
+        val queryUrlRelevances = urlRelevances(queryId)
+        val queryUrlRelFractions = urlRelFractions(queryId)
+        var urlId = config.maxUrlIds(queryId)
+        // iterate over urls related to the query
+        while (urlId >= 0) {
+          val relFractions = queryUrlRelFractions(urlId)
+          val rel = queryUrlRelevances(urlId)
+          // Convert our sums of per-session a_u and s_u into probabilities (domain of [0,1])
+          // attracted / (attracted + not-attracted)
+          rel.a = relFractions.a(1) / (relFractions.a(1) + relFractions.a(0))
+          // satisfied / (satisfied + not-satisfied)
+          rel.s = relFractions.s(1) / (relFractions.s(1) + relFractions.s(0))
+
+          // Reset rel-fractions for next iteration
+          relFractions.a(0) = 1D
+          relFractions.a(1) = 1D
+          relFractions.s(0) = 1D
+          relFractions.s(1) = 1D
+          urlId -= 1
+        }
+        queryId -= 1
+      }
     }
+
+    urlRelevances
   }
 
-  override def iterator: Iterator[(K, V)] = map.iterator
+  val positionRelevances = new PositionRel(new Array[Double](config.serpSize), new Array[Double](config.serpSize))
+  // By pre-allocating we only have to fill these structures on the first iteration. After that we avoid
+  // allocation and reuse what we already know we need. It's important that the train method reset these
+  // to 1D after each iteration.
+  //
+  // urlRelFraction(queryId)(urlId)
+  val urlRelFractions: Array[Array[UrlRelFrac]] = (0 to config.maxQueryId).map { queryId =>
+    (0 to config.maxUrlIds(queryId)).map { _ =>
+      new UrlRelFrac(Array.fill(2)(1D), Array.fill(2)(1D))
+    }.toArray
+  }.toArray
 
-  // converts to immutable scala Map
-  def asMap: Map[K, V] = map.toMap
-}
+  // E step
+  private def eStep(urlRelevances: Array[Array[UrlRel]], sessions: Seq[SessionItem])
+  : Array[Array[UrlRelFrac]] = {
+    var sidx = 0
+    while (sidx < sessions.length) {
+      val s = sessions(sidx)
+      var i = 0
+      val urlRelQuery = urlRelevances(s.queryId)
+      val N = Math.min(config.serpSize, s.urlIds.length)
+      while (i < N) {
+        val urlRel = urlRelQuery(s.urlIds(i))
+        positionRelevances.a(i) = urlRel.a
+        positionRelevances.s(i) = urlRel.s
+        i += 1
+      }
 
-object DbnModel {
+      val sessionEstimate = getSessionEstimate(positionRelevances, s.clicks)
+      val queryUrlRelFrac = urlRelFractions(s.queryId)
+      i = 0
+      while (i < N) {
+        var urlId = s.urlIds(i)
+        // update attraction
+        val rel = queryUrlRelFrac(urlId)
+        val estA = sessionEstimate.a(i)
+        rel.a(0) += 1 - estA
+        rel.a(1) += estA
+        if (s.clicks(i)) {
+          // update satisfaction
+          val estS = sessionEstimate.s(i)
+          rel.s(0) += 1 - estS
+          rel.s(1) += estS
+        }
+        i += 1
+      }
+      sidx += 1
+    }
+    urlRelFractions
+  }
+
+  // To keep from allocating while running the DBN create our intermediate
+  // arrays at the largest size that might be needed. We must be careful to
+  // never calculate based on the lengths of these, but instead on the lengths
+  // of the input.
+  val updateMatrix: Array[Array[Array[Double]]] = Array.ofDim(config.serpSize, 2, 2)
+  // alpha(i)(e) = P(C_1,...C_{i-1},E_i=e|a_u,s_u,G) calculated forwards for C_1, then C_1,C_2, ...
+  val alpha: Array[Array[Double]] = Array.ofDim(config.serpSize + 1, 2)
+  // beta(i)(e) = P(C_{i+1},...C_N|E_i=e,a_u,s_u,G) calculated backwards for C_10, then C_9, C_10, ...
+  val beta: Array[Array[Double]] = Array.ofDim(config.serpSize + 1, 2)
+  val varphi: Array[Double] = new Array(config.serpSize + 1)
+
   /**
    * The forward-backward algorithm is used to compute the posterior probabilities of the hidden variables.
     *
@@ -239,29 +360,28 @@
     * P(E_{i+1}=0|E_i=0,S_i=1) = 1 (5g)
     * P(E_{i+1}=1|E_i=0,S_i=1) = 0 (5g)
     */
-  def getForwardBackwardEstimates(rel: PositionRel, gamma: Double, clicks: Array[Boolean]): (Array[Array[Double]], Array[Array[Double]]) = {
-    val N = clicks.length
-    // alpha(i)(e) = P(C_1,...C_{i-1},E_i=e|a_u,s_u,G) calculated forwards for C_1, then C_1,C_2, ...
-    val alpha = Array.ofDim[Double](N + 1, 2)
-    // beta(i)(e) = P(C_{i+1},...C_N|E_i=e,a_u,s_u,G) calculated backwards for C_10, then C_9, C_10, ...
-    val beta = Array.ofDim[Double](N + 1, 2)
+  def calcForwardBackwardEstimates(rel: PositionRel, clicks: Array[Boolean]): Unit = {
+    val N = Math.min(config.serpSize, clicks.length)
 
+    // always 0: alpha(0)(0) = 0D
     alpha(0)(1) = 1D
     beta(N)(0) = 1D
     beta(N)(1) = 1D
 
     // Forwards (alpha) and backwards (beta) need the same probabilities as inputs so pre-calculate them.
     var i = 0
-    val updateMatrix: Array[Array[Array[Double]]] = Array.ofDim[Double](clicks.length, 2, 2)
     while (i < N) {
       val a_u = rel.a(i)
       val s_u = rel.s(i)
       if (clicks(i)) {
-        updateMatrix(i)(0)(1) = (s_u + (1 - gamma) * (1 - s_u)) * a_u
+        updateMatrix(i)(0)(0) = 0D
+        updateMatrix(i)(0)(1) = (s_u + invGamma * (1 - s_u)) * a_u
+        // always 0: updateMatrix(i)(1)(0) = 0D
         updateMatrix(i)(1)(1) = gamma * (1 - s_u) * a_u
       } else {
         updateMatrix(i)(0)(0) = 1D
-        updateMatrix(i)(0)(1) = (1D - gamma) * (1D - a_u)
+        updateMatrix(i)(0)(1) = invGamma * (1D - a_u)
+        // always 0: updateMatrix(i)(1)(0) = 0D
         updateMatrix(i)(1)(1) = gamma * (1D - a_u)
       }
       i += 1
@@ -269,135 +389,75 @@
 
     i = 0
     while (i < N) {
-        // alpha(i+1)(e) = sum for e' in {0,1} of alpha(i)(e') * updateMatrix(i)(e)(e')
-        alpha(i + 1)(0) =
-          alpha(i)(0) * updateMatrix(i)(0)(0) +
-          alpha(i)(1) * updateMatrix(i)(0)(1)
-        alpha(i + 1)(1) =
-          alpha(i)(0) * updateMatrix(i)(1)(0) +
-          alpha(i)(1) * updateMatrix(i)(1)(1)
+      // alpha(i+1)(e) = sum for e' in {0,1} of alpha(i)(e') * updateMatrix(i)(e)(e')
+      alpha(i + 1)(0) =
+        alpha(i)(0) * updateMatrix(i)(0)(0) +
+        alpha(i)(1) * updateMatrix(i)(0)(1)
+      alpha(i + 1)(1) =
+        // always 0: alpha(i)(0) * updateMatrix(i)(1)(0) +
+        alpha(i)(1) * updateMatrix(i)(1)(1)
 
-        // beta(N-1-i)(e) = sum for e' in {0,1} of beta(N-1-i)(e') * updateMatrix(i)(e)(e')
-        beta(N - 1 - i)(0) =
-          beta(N - i)(0) * updateMatrix(N - 1 - i)(0)(0) +
-          beta(N - i)(1) * updateMatrix(N - 1 - i)(1)(0)
-        beta(N - 1 - i)(1) =
-          beta(N - i)(0) * updateMatrix(N - 1 - i)(0)(1) +
-          beta(N - i)(1) * updateMatrix(N - 1 - i)(1)(1)
+      // beta(N-1-i)(e) = sum for e' in {0,1} of beta(N-i)(e') * updateMatrix(N-1-i)(e')(e)
+      beta(N - 1 - i)(0) =
+        beta(N - i)(0) * updateMatrix(N - 1 - i)(0)(0)
+        // always 0: + beta(N - i)(1) * updateMatrix(N - 1 - i)(1)(0)
+      beta(N - 1 - i)(1) =
+        beta(N - i)(0) * updateMatrix(N - 1 - i)(0)(1) +
+        beta(N - i)(1) * updateMatrix(N - 1 - i)(1)(1)
       i += 1
     }
 
-    (alpha, beta)
+    // (alpha, beta)
   }
 
+  var sessionEstimate = new PositionRel(new Array[Double](config.serpSize), new Array[Double](config.serpSize))
   // Returns
-  //  a: P(A_i|C_i,G) - Probability of attractiveness at position k conditioned on clicked and gamma
-  //  s: P(S_i|C_i,G) - Probability of satisfaction at position k conditioned on clicked and gamma
-  def getSessionEstimate(rel: PositionRel, gamma: Double, clicks: Array[Boolean]): PositionRel = {
-    val N = clicks.length
-    // alpha(i)(e) is P(C_1,...,C_{i-1},E_i=e|a_u,s_u,G)
+  //  a: P(A_i|C_i,G) - Probability of attractiveness at position i conditioned on clicked and gamma
+  //  s: P(S_i|C_i,G) - Probability of satisfaction at position i conditioned on clicked and gamma
+  def getSessionEstimate(rel: PositionRel, clicks: Array[Boolean]): PositionRel = {
+    val N = Math.min(config.serpSize, clicks.length)
+
+    // This sets the instance variables alpha/beta
+    // alpha(i)(e) is P(C_1,...,C_{i-1},E_i=e|a_u,s_u,G)
     // beta(i)(e) is P(C_{i+1},...,C_N|E_i=e,a_u,s_u,G)
-    val (alpha, beta) = DbnModel.getForwardBackwardEstimates(rel, gamma, clicks)
+    calcForwardBackwardEstimates(rel, clicks)
 
     // varphi is the smoothing of the forwards and backwards. I think, based on the wiki page on the
     // forward/backward algorithm, that varphi is then P(E_i|C_1,...,C_N,a_u,s_u,G) but not 100% sure...
-    var k = 0
-    val varphi: Array[Double] = new Array(alpha.length)
-    while (k < alpha.length) {
-      val a = alpha(k)
-      val b = beta(k)
+    var i = 0
+    while (i < N + 1) {
+      val a = alpha(i)
+      val b = beta(i)
       val ab0 = a(0) * b(0)
       val ab01 = ab0 + a(1) * b(1)
-      varphi(k) = ab0 / ab01
-      k += 1
+      varphi(i) = ab0 / ab01
+      i += 1
     }
 
-    val sessionEstimate = new PositionRel(new Array[Double](N), new Array[Double](N))
-    k = 0
-    while (k < N) {
-      val a_u = rel.a(k)
-      val s_u = rel.s(k)
-      // E_i_multiplier --- P(S_i=0|C_i)P(C_i|E_i=1) (eq 6)
-      if (clicks(k)) {
+    i = 0
+    while (i < N) {
+      val a_u = rel.a(i)
+      val s_u = rel.s(i)
+
+      // TODO: Clickmodels had this as S_i=0, but I'm pretty sure it's supposed to be =1
+      // based on the actual updates performed?
+      // E_i_multiplier --- P(S_i=0|C_i)P(C_i|E_i=1) (inverse of eq 6)
+      if (clicks(i)) {
         // if user clicked attraction = 1 (eq 5a)
-        sessionEstimate.a(k) = 1D
+        sessionEstimate.a(i) = 1D
         // if user clicked satisfaction is (why?)
         //   prob of examination of next * satisfaction / (satisfaction + likelihood of abandonment * dissatisfaction)
-        sessionEstimate.s(k) = varphi(k + 1) * s_u / (s_u + (1 - gamma) * (1 - s_u))
+        sessionEstimate.s(i) = varphi(i + 1) * s_u / (s_u + invGamma * (1 - s_u))
       } else {
-        // with no click attraction = attraction * prob of examination (why?)
-        sessionEstimate.a(k) = a_u * varphi(k)
+        // probability of no click when examined = attraction * prob of examination (why?)
+        sessionEstimate.a(i) = a_u * varphi(i)
         // with no click satisfaction = 0 (eq 5d)
-        sessionEstimate.s(k) = 0D
+        sessionEstimate.s(i) = 0D
       }
-      k += 1
+      i += 1
     }
+
     sessionEstimate
-  }
-}
-
-class DbnModel(gamma: Double, config: Config) {
-
-  def train(sessions: Seq[SessionItem]): Array[Map[Int, UrlRel]] = {
-    // This is basically a multi-dimensional array with queryId in the first
-    // dimension and urlId in the second dimension. Because queries only reference
-    // a subset of the known urls we use a map at the second level instead of
-    // creating the entire matrix.
-    val urlRelevances: Array[DefaultMap[Int, UrlRel]] = Array.fill(config.maxQueryId) {
-      new DefaultMap[Int, UrlRel]({
-        _ => new UrlRel(config.defaultRel, config.defaultRel)
-      })
-    }
-
-    for (_ <- 0 until config.maxIterations) {
-      for ((d, queryId) <- eStep(urlRelevances, sessions).view.zipWithIndex) {
-        // M step
-        for ((urlId, relFractions) <- d) {
-          val rel = urlRelevances(queryId)(urlId)
-          // Convert our sums of per-session a_u and s_u into probabilities (domain of [0,1])
-          // attracted / (attracted + not-attracted)
-          rel.a = relFractions.a(1) / (relFractions.a(1) + relFractions.a(0))
-          // satisfied / (satisfied + not-satisfied)
-          rel.s = relFractions.s(1) / (relFractions.s(1) + relFractions.s(0))
-        }
-      }
-    }
-
-    urlRelevances.map(_.asMap)
-  }
-
-  // E step
-  private def eStep(urlRelevances: Array[DefaultMap[Int, UrlRel]], sessions: Seq[SessionItem])
-  : Array[DefaultMap[Int, UrlRelFrac]] = {
-    // urlRelFraction(queryId)(urlId)
-    val urlRelFractions: Array[DefaultMap[Int, UrlRelFrac]] = Array.fill(config.maxQueryId) {
-      new DefaultMap[Int, UrlRelFrac]({
-        _ => new UrlRelFrac(Array.fill(2)(1D), Array.fill(2)(1D))
-      })
-    }
-
-    for (s <- sessions) {
-      val positionRelevances = new PositionRel(
-        s.urlIds.map(urlRelevances(s.queryId)(_).a),
-        s.urlIds.map(urlRelevances(s.queryId)(_).s)
-      )
-
-      val sessionEstimate = DbnModel.getSessionEstimate(positionRelevances, gamma, s.clicks)
-      for ((urlId, i) <- s.urlIds.view.zipWithIndex) {
-        // update attraction
-        val rel = urlRelFractions(s.queryId)(urlId)
-        val estA = sessionEstimate.a(i)
-        rel.a(0) += (1 - estA)
-        rel.a(1) += estA
-        if (s.clicks(i)) {
-          // update satisfaction
-          val estS = sessionEstimate.s(i)
-          rel.s(0) += (1 - estS)
-          rel.s(1) += estS
-        }
-      }
-    }
-    urlRelFractions
   }
 }
 
diff --git a/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala b/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
index 6bfbaae..b51fc7f 100644
--- a/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
+++ b/jvm/src/test/scala/org/wikimedia/search/mjolnir/DBNSuite.scala
@@ -1,7 +1,9 @@
 package org.wikimedia.search.mjolnir
 
 import org.scalatest.FunSuite
+
 import scala.io.Source
+import scala.util.Random
 
 class DBNSuite extends FunSuite {
   test("create session items") {
@@ -12,6 +14,46 @@
       Array(false, true, false))
 
     assert(item.isDefined)
+  }
+
+  test("session items are truncated to serpSize") {
+    val serpSize = 20
+    val ir = new InputReader(1, serpSize, true)
+    val urls = (0 to 30).map(_.toString).toArray
+    val clicks = Array.fill(30)(false)
+    clicks(2) = true
+
+    val maybeItem = ir.makeSessionItem("foo", "enwiki", urls, clicks)
+    assert(maybeItem.isDefined)
+    val item = maybeItem.get
+    assert(item.clicks.length == serpSize)
+    assert(item.urlIds.length == serpSize)
+  }
+
+  test("no urls gives no session item") {
+    val ir = new InputReader(1, 2, true)
+    val urls = new Array[String](0)
+    val clicks = new Array[Boolean](0)
+    assert(ir.makeSessionItem("foo", "enwiki", urls, clicks).isEmpty)
+  }
+
+  test("no clicks gives no session item") {
+    val ir = new InputReader(1, 2, true)
+    val urls = Array("a", "b", "c")
+    val clicks = Array(false, false, false)
+    assert(ir.makeSessionItem("foo", "enwiki", urls, clicks).isEmpty)
+  }
+
+  test("clicks are padded with false up to url count") {
+    val ir = new InputReader(1, 5, true)
+    val urls = (0 until 5).map(_.toString).toArray
+    val clicks = Array(false, true)
+    val maybeItem = ir.makeSessionItem("foo", "enwiki", urls, clicks)
+    assert(maybeItem.isDefined)
+    val item = maybeItem.get
+    assert(item.clicks.length == 5)
+    assert(item.clicks(1))
+    assert(item.clicks.map(if (_) 1 else 0).sum == 1)
   }
 
   test("create session item from line") {
@@ -29,7 +71,7 @@
     val file = Source.fromURL(getClass.getResource("/dbn.data"))
     val ir = new InputReader(1, 20, true)
     val sessions = ir.read(file.getLines())
-    val config = new Config(ir.maxQueryId, 0.5D, 1)
+    val config = ir.config(0.5D, 1)
     val model = new DbnModel(0.9D, config)
     val urlRelevances = model.train(sessions)
     val relevances = ir.toRelevances(urlRelevances)
@@ -61,6 +103,23 @@
     }
   }
 
+  test("providing more results than expected still works") {
+    val N = 30
+    val clicks = Array.fill(N)(false)
+    clicks(2) = true
+
+    val sessions = Seq(
+      new SessionItem(0, (0 until N).toArray, clicks),
+      new SessionItem(0, (0 until N).toArray, clicks)
+    )
+
+    val config = new Config(0, 0.5D, 2, 20, Array(N))
+    val model = new DbnModel(0.9D, config)
+    model.train(sessions)
+    // no exceptions thrown
+    assert(true)
+  }
+
   test("backwards forwards") {
     val rel = new PositionRel(
       Array.fill(20)(0.5D), Array.fill(20)(0.5D)
@@ -68,9 +127,10 @@
     val gamma = 0.9D
     val clicks = Array.fill(20)(false)
 
-    val foo = DbnModel.getForwardBackwardEstimates(rel, gamma, clicks)
-    val alpha = foo._1
-    val beta = foo._2
+    val model = new DbnModel(0.5D, new Config(0, 0.5D, 1, 20, Array(20)))
+    model.calcForwardBackwardEstimates(rel, clicks)
+    val alpha = model.alpha
+    val beta = model.beta
     val x = alpha(0)(0) * beta(0)(0) + alpha(0)(1) * beta(0)(1)
 
     val ok: Array[Boolean] = alpha.zip(beta).map { case (a: Array[Double], b: Array[Double]) =>
@@ -84,15 +144,16 @@
     // Values sourced from python clickmodels implementation
     val rel = new PositionRel(Array.fill(20)(0.5D), Array.fill(20)(0.5D))
     val clicks = Array.fill(20)(false)
+    val model = new DbnModel(0.9D, new Config(0, 0.5D, 1, 20, Array(20)))
 
     clicks(0) = true
-    var sessionEstimate = DbnModel.getSessionEstimate(rel, 0.9D, clicks)
+    var sessionEstimate = model.getSessionEstimate(rel, clicks)
     assert(math.abs(sessionEstimate.a.sum - 10.4370D) < 0.0001D)
     assert(math.abs(sessionEstimate.s.sum - 0.8461D) < 0.0001D)
     assert(math.abs(sessionEstimate.s.sum - sessionEstimate.s(0)) < 0.0001D)
 
     clicks(10) = true
-    sessionEstimate = DbnModel.getSessionEstimate(rel, 0.9D, clicks)
+    sessionEstimate = model.getSessionEstimate(rel, clicks)
     assert(math.abs(sessionEstimate.a.sum - 6.4347D) < 0.0001D)
     assert(math.abs(sessionEstimate.s.sum - 0.8457D) < 0.0001D)
     assert(math.abs(sessionEstimate.s.sum - sessionEstimate.s(0) - sessionEstimate.s(10)) < 0.0001D)
@@ -105,7 +166,7 @@
     val nIterations = 40
     val nResultsPerQuery = 20
 
-    val r = new scala.util.Random(0)
+    val r = new Random(0)
     val ir = new InputReader(10, 20, true)
     val sessions = (0 until nQueries).flatMap { query =>
       val urls: Array[String] = (0 until nResultsPerQuery).map { _ => r.nextInt.toString }.toArray
@@ -121,12 +182,14 @@
 
     assert(sessions.length == nQueries * nSessionsPerQuery)
 
-    val config = new Config(ir.maxQueryId, 0.5D, nIterations)
+    val config = ir.config(0.5D, nIterations)
     val dbn = new DbnModel(0.9D, config)
-    val start = System.nanoTime()
-    dbn.train(sessions)
-    val took = System.nanoTime() - start
-    println(s"Took ${took/1000000}ms")
+    (0 until 5).foreach { _ =>
+      val start = System.nanoTime()
+      dbn.train(sessions)
+      val took = System.nanoTime() - start
+      println(s"Took ${took / 1000000}ms")
+    }
 
     // Create a datafile that python clickmodels can read in to have a fair comparison
     //import java.io.File

-- 
To view, visit https://gerrit.wikimedia.org/r/394741
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: merged
Gerrit-Change-Id: I08b72b98f515a820675e1ef9b45dd8724cbd070e
Gerrit-PatchSet: 4
Gerrit-Project: search/MjoLniR
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <ebernhard...@wikimedia.org>
Gerrit-Reviewer: DCausse <dcau...@wikimedia.org>
Gerrit-Reviewer: EBernhardson <ebernhard...@wikimedia.org>
Gerrit-Reviewer: jenkins-bot <>
