Tjones has submitted this change and it was merged. ( https://gerrit.wikimedia.org/r/347038 )

Change subject: Add DBN training
......................................................................


Add DBN training

Adds mjolnir.dbn for training a DBN. Takes a DataFrame containing user
search sessions, trains the dbn, and outputs a DataFrame containing
(wikiid, norm_query, hit_page_id, relevance) rows. This seemed like a
nice and small place to start figuring out how to add unit tests for
pyspark.

The tests are pretty minor, and probably don't begin to test all the
edge cases, but we can work those out later. To run the tests, bring up
the vagrant machine, cd into /vagrant, and run:

  bin/pytest --pyargs mjolnir
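
As a rough sketch of how train() is meant to be called (the DataFrame
construction is assumed, and the config values below are placeholders;
see mjolnir/test/test_dbn.py for the set of keys exercised there):

  import mjolnir.dbn

  # df must provide wikiid, norm_query, session_id, hit_page_id,
  # hit_position and clicked columns.
  labeled = mjolnir.dbn.train(df, {
      'MAX_ITERATIONS': 40,             # placeholder value
      'DEBUG': False,
      'PRETTY_LOG': True,
      'MIN_DOCS_PER_QUERY': 10,         # placeholder value
      'MAX_DOCS_PER_QUERY': 20,         # placeholder value
      'SERP_SIZE': 20,                  # placeholder value
      'QUERY_INDEPENDENT_PAGER': False,
      'DEFAULT_REL': 0.5,
  }, num_partitions=200)
  # labeled is a DataFrame of (wikiid, norm_query, hit_page_id, relevance)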

Bug: T162075
Change-Id: Id0dd1cb267b90b0aaca096f471e5b35ebd88f156
---
M README
A mjolnir/dbn.py
A mjolnir/spark.py
A mjolnir/test/__init__.py
A mjolnir/test/conftest.py
A mjolnir/test/fixtures/dbn_input.json
A mjolnir/test/test_dbn.py
D poc/data_process_dbn.py
8 files changed, 344 insertions(+), 130 deletions(-)

Approvals:
  Tjones: Verified; Looks good to me, approved



diff --git a/README b/README
index ae2700e..e0b5a60 100644
--- a/README
+++ b/README
@@ -9,3 +9,7 @@
 
 Targets pyspark 1.6.0 running on python 2.7
 
+== Other
+
+Documentation follows the numpy documentation guidelines:
+    https://github.com/numpy/numpy/blob/master/doc/HOWTO_DOCUMENT.rst.txt
diff --git a/mjolnir/dbn.py b/mjolnir/dbn.py
new file mode 100644
index 0000000..958f302
--- /dev/null
+++ b/mjolnir/dbn.py
@@ -0,0 +1,194 @@
+"""
+Implements training a Dynamic Bayesian Network, using the clickmodels library,
+within spark
+"""
+
+from clickmodels.inference import DbnModel
+from clickmodels.input_reader import InputReader
+import json
+import pyspark.sql
+from pyspark.sql import functions as F
+import mjolnir.spark
+
+
+def _deduplicate_hits(session_hits):
+    """Deduplicate multiple views of a hit by a single session.
+
+    A single session may have seen the same result list multiple times, for
+    example by clicking a link, then clicking back and clicking a second link.
+    Normalize that data together into a single record per hit_page_id even if
+    it was displayed to a session multiple times.
+
+    Parameters
+    ----------
+    session_hits : list
+        A list of hits seen by a single session.
+
+    Returns
+    -------
+    list
+        List of hits shown to a session de-duplicated to contain only one entry
+        per hit_page_id.
+    """
+    by_hit_page_id = {}
+    for hit in session_hits:
+        if hit.hit_page_id in by_hit_page_id:
+            by_hit_page_id[hit.hit_page_id].append(hit)
+        else:
+            by_hit_page_id[hit.hit_page_id] = [hit]
+
+    deduped = []
+    for hit_page_id, hits in by_hit_page_id.iteritems():
+        hit_positions = []
+        clicked = False
+        for hit in hits:
+            hit_positions.append(hit.hit_position)
+            clicked |= bool(hit.clicked)
+        deduped.append(pyspark.sql.Row(
+            hit_page_id=hit_page_id,
+            hit_position=sum(hit_positions) / float(len(hit_positions)),
+            clicked=clicked))
+    return deduped
+
+
+def _gen_dbn_input(iterator):
+    """Converts an iterator over spark rows into the DBN input format.
+
+    It is perhaps undesirable that we serialize into a string with json so
+    InputReader can deserialize, but it is not generic enough to avoid this
+    step.
+
+    Parameters
+    ----------
+    iterator : ???
+        iterator over pyspark.sql.Row. Each row must have a wikiid,
+        norm_query, and list of hits each containing hit_position,
+        hit_page_id and clicked.
+
+    Yields
+    -------
+    string
+        Line for a single item of the input iterator formatted for use
+        by clickmodels InputReader.
+    """
+    for row in iterator:
+        results = []
+        clicks = []
+        deduplicated = _deduplicate_hits(row.hits)
+        deduplicated.sort(key=lambda hit: hit.hit_position)
+        for hit in deduplicated:
+            results.append(str(hit.hit_page_id))
+            clicks.append(hit.clicked)
+        yield '\t'.join([
+            '0',  # unused identifier
+            row.norm_query,  # group the session belongs to
+            row.wikiid,  # region
+            '0',  # intent weight
+            json.dumps(results),  # hits displayed in session
+            json.dumps([False] * len(results)),  # layout (unused)
+            json.dumps(clicks)  # Was result clicked
+        ])
+
+
+def _extract_labels_from_dbn(model, reader):
+    """Extracts all learned labels from the model.
+
+    Parameters
+    ----------
+    model : clickmodels.inference.DbnModel
+        A trained DBN model
+    reader : clickmodels.input_reader.InputReader
+        Reader that was used to build the list of SessionItem's model was
+        trained with.
+
+    Returns
+    -------
+    list of tuples
+        List of four value tuples each containing wikiid, norm_query,
+        hit_page_id and relevance.
+    """
+    # reader converted all the page ids into an internal id, flip the map so we
+    # can change them back. Not the most memory efficient, but it will do.
+    uid_to_url = {uid: url for url, uid in reader.url_to_id.iteritems()}
+    rows = []
+    for (norm_query, wikiid), qid in reader.query_to_id.iteritems():
+        for uid, data in model.urlRelevances[False][qid].iteritems():
+            relevance = data['a'] * data['s']
+            hit_page_id = int(uid_to_url[uid])
+            rows.append((wikiid, norm_query, hit_page_id, relevance))
+    return rows
+
+
+def train(df, dbn_config, num_partitions=200):
+    """Generate relevance labels for the provided dataframe.
+
+    Process the provided data frame to generate relevance scores for
+    all provided pairs of (wikiid, norm_query, hit_page_id). The input
+    DataFrame must have a row per hit_page_id that was seen by a session.
+
+    Parameters
+    ----------
+    df : pyspark.sql.DataFrame
+        User click logs with columns wikiid, norm_query, session_id,
+        hit_page_id, hit_position, clicked.
+    dbn_config : dict
+        Configuration needed by the DBN. See clickmodels documentation for more
+        information.
+    num_partitions : int
+        The number of partitions to split input data into for training.
+        Training will load the entire partition into python to feed into the
+        DBN, so a large enough number of partitions need to be used that we
+        don't blow out executor memory.
+
+    Returns
+    -------
+    spark.sql.DataFrame
+        DataFrame with columns wikiid, norm_query, hit_page_id, relevance.
+    """
+    mjolnir.spark.assert_columns(df, ['wikiid', 'norm_query', 'session_id',
+                                      'hit_page_id', 'hit_position', 'clicked'])
+
+    def train_partition(iterator):
+        """Learn the relevance labels for a single DataFrame partition.
+
+        Before applying to a partition ensure that sessions for queries are not
+        split between multiple partitions.
+
+        Parameters
+        ----------
+        iterator : iterator over pyspark.sql.Row's.
+
+        Returns
+        -------
+        list of tuples
+            List of (wikiid, norm_query, hit_page_id, relevance) tuples.
+        """
+        reader = InputReader(dbn_config['MIN_DOCS_PER_QUERY'],
+                             dbn_config['MAX_DOCS_PER_QUERY'],
+                             False,
+                             dbn_config['SERP_SIZE'],
+                             False,
+                             discard_no_clicks=True)
+        sessions = reader(_gen_dbn_input(iterator))
+        dbn_config['MAX_QUERY_ID'] = reader.current_query_id + 1
+        model = DbnModel((0.9, 0.9, 0.9, 0.9), config=dbn_config)
+        model.train(sessions)
+        return _extract_labels_from_dbn(model, reader)
+
+    return (
+        df
+        # group and collect up the hits for individual (wikiid, norm_query,
+        # session_id) tuples to match how the dbn expects to receive data.
+        .groupby('wikiid', 'norm_query', 'session_id')
+        .agg(F.collect_list(F.struct('hit_position', 'hit_page_id', 'clicked')).alias('hits'))
+        # Partition into small batches ensuring that all matching (wikiid,
+        # norm_query) rows end up on the same partition.
+        # TODO: The above groupby and this repartition both cause a shuffle, is
+        # it possible to make that a single shuffle? Could push the final level
+        # of grouping into python, but that could just as well end up worse?
+        .repartition(num_partitions, 'wikiid', 'norm_query')
+        # Run each partition through the DBN to generate relevance scores.
+        .mapPartitions(train_partition)
+        # Convert the rdd of tuples back into a DataFrame so the fields all
+        # have a name.
+        .toDF(['wikiid', 'norm_query', 'hit_page_id', 'relevance']))
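
(For reference, each line yielded by _gen_dbn_input above is the
tab-separated record format that clickmodels' InputReader parses. For
the "test"/"abc" session in the fixture below, the generated line would
look roughly like this, with tab characters shown as " | ":

  0 | test | foowiki | 0 | ["1111", "2222", "3333", "4444"] | [false, false, false, false] | [false, true, false, false]

i.e. an unused id, the norm_query, the wikiid standing in for the region
field, an intent weight of 0, the displayed hit_page_ids, the unused
layout, and the per-hit clicks.)
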
diff --git a/mjolnir/spark.py b/mjolnir/spark.py
new file mode 100644
index 0000000..8202849
--- /dev/null
+++ b/mjolnir/spark.py
@@ -0,0 +1,18 @@
+"""
+Helper functions for dealing with pyspark
+"""
+
+
+def assert_columns(df, columns):
+    """ Raise an exception if the dataframe
+    does not contain the desired columns
+    Parameters
+    ----------
+    df : pyspark.sql.DataFrame
+    columns : list of strings
+        Set of columns that must be present in df
+    """
+    have = set(df.columns)
+    need = set(columns)
+    if not need.issubset(have):
+        raise ValueError("Missing columns in DataFrame: %s" % (", 
".join(need.difference(have))))
diff --git a/mjolnir/test/__init__.py b/mjolnir/test/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/mjolnir/test/__init__.py
diff --git a/mjolnir/test/conftest.py b/mjolnir/test/conftest.py
new file mode 100644
index 0000000..aa93ae6
--- /dev/null
+++ b/mjolnir/test/conftest.py
@@ -0,0 +1,45 @@
+import findspark
+findspark.init()
+
+import pytest
+import logging
+from pyspark import SparkContext, SparkConf
+from pyspark.sql import HiveContext
+
+
+def quiet_log4j():
+    logger = logging.getLogger('py4j')
+    logger.setLevel(logging.WARN)
+
+
+@pytest.fixture(scope="session")
+def spark_context(request):
+    """Fixture for creating a spark context.
+
+    Args:
+        request: pytest.FixtureRequest object
+
+    Returns:
+        SparkContext for tests
+    """
+    quiet_log4j()
+    conf = (
+        SparkConf()
+        .setMaster("local[2]")
+        .setAppName("pytest-pyspark-local-testing"))
+    sc = SparkContext(conf=conf)
+    yield sc
+    sc.stop()
+
+
+@pytest.fixture(scope="session")
+def hive_context(spark_context):
+    """Fixture for creating a Hive context.
+
+    Args:
+        spark_context: spark_context fixture
+
+    Returns:
+        HiveContext for tests
+    """
+    return HiveContext(spark_context)
diff --git a/mjolnir/test/fixtures/dbn_input.json b/mjolnir/test/fixtures/dbn_input.json
new file mode 100644
index 0000000..648a446
--- /dev/null
+++ b/mjolnir/test/fixtures/dbn_input.json
@@ -0,0 +1,33 @@
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "abc", 
"hit_page_id": 1111, "hit_position": 1, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "abc", 
"hit_page_id": 2222, "hit_position": 2, "clicked": true}
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "abc", 
"hit_page_id": 3333, "hit_position": 3, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "abc", 
"hit_page_id": 4444, "hit_position": 4, "clicked": false}
+
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "def", 
"hit_page_id": 1111, "hit_position": 1, "clicked": true}
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "def", 
"hit_page_id": 2222, "hit_position": 2, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "def", 
"hit_page_id": 3333, "hit_position": 3, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "def", 
"hit_page_id": 4444, "hit_position": 4, "clicked": false}
+
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "ghi", 
"hit_page_id": 1111, "hit_position": 1, "clicked": true}
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "ghi", 
"hit_page_id": 2222, "hit_position": 2, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "ghi", 
"hit_page_id": 3333, "hit_position": 3, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "test", "session_id": "ghi", 
"hit_page_id": 4444, "hit_position": 4, "clicked": false}
+
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "abc", 
"hit_page_id": 1111, "hit_position": 1, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "abc", 
"hit_page_id": 2222, "hit_position": 2, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "abc", 
"hit_page_id": 3333, "hit_position": 3, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "abc", 
"hit_page_id": 4444, "hit_position": 4, "clicked": true}
+
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "def", 
"hit_page_id": 1111, "hit_position": 1, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "def", 
"hit_page_id": 2222, "hit_position": 2, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "def", 
"hit_page_id": 3333, "hit_position": 3, "clicked": true}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "def", 
"hit_page_id": 4444, "hit_position": 4, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "def", 
"hit_page_id": 1111, "hit_position": 1, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "def", 
"hit_page_id": 2222, "hit_position": 2, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "def", 
"hit_page_id": 3333, "hit_position": 3, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "def", 
"hit_page_id": 4444, "hit_position": 4, "clicked": true}
+
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "ghi", 
"hit_page_id": 1111, "hit_position": 1, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "ghi", 
"hit_page_id": 2222, "hit_position": 2, "clicked": true}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "ghi", 
"hit_page_id": 3333, "hit_position": 3, "clicked": false}
+{"wikiid": "foowiki", "norm_query": "zomg", "session_id": "ghi", 
"hit_page_id": 4444, "hit_position": 4, "clicked": false}
diff --git a/mjolnir/test/test_dbn.py b/mjolnir/test/test_dbn.py
new file mode 100644
index 0000000..4c8eb3f
--- /dev/null
+++ b/mjolnir/test/test_dbn.py
@@ -0,0 +1,50 @@
+import os
+import mjolnir.dbn
+
+
+def test_dbn_train(hive_context):
+    df = hive_context.read.json(os.path.join(os.path.dirname(__file__), "fixtures/dbn_input.json"))
+    labeled = mjolnir.dbn.train(df, {
+        # Don't use this config for prod, it's specifically for small testing
+        'MAX_ITERATIONS': 1,
+        'DEBUG': False,
+        'PRETTY_LOG': True,
+        'MIN_DOCS_PER_QUERY': 1,
+        'MAX_DOCS_PER_QUERY': 4,
+        'SERP_SIZE': 4,
+        'QUERY_INDEPENDENT_PAGER': False,
+        'DEFAULT_REL': 0.5,
+    }, num_partitions=20)
+    assert len(labeled.columns) == 4
+    assert 'wikiid' in labeled.columns
+    assert 'norm_query' in labeled.columns
+    assert 'hit_page_id' in labeled.columns
+    assert 'relevance' in labeled.columns
+
+    # Make sure we didn't drop data somewhere
+    data = labeled.collect()
+    assert len(data) == 8, "Expecting 4 relevance labels * 2 queries in fixtures"
+
+    # Make sure wikiid is kept through the process
+    wikiids = set([row.wikiid for row in data])
+    assert len(wikiids) == 1
+    assert u'foowiki' in wikiids
+
+    # Make sure the set of unique queries is kept
+    queries = set([row.norm_query for row in data])
+    assert len(queries) == 2
+    assert u'test' in queries
+    assert u'zomg' in queries
+
+    # Make sure the dbn is provided data in the right order, by looking at what comes out
+    # at the top and bottom of each query. This should also detect if something went wrong
+    # with partitioning, causing parts of a query to train in separate DBN's
+    test = sorted([row for row in data if row.norm_query == u'test'], key=lambda row: row.relevance, reverse=True)
+    assert test[0].hit_page_id == 1111
+    assert test[3].hit_page_id == 3333
+
+    zomg = sorted([row for row in data if row.norm_query == u'zomg'], key=lambda row: row.relevance, reverse=True)
+    assert zomg[0].hit_page_id == 4444
+    assert zomg[3].hit_page_id == 1111
+    # page 1111 should have been skipped every time, resulting in a very low score
+    assert zomg[3].relevance == 0.1
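
Note that pytest's --pyargs flag also accepts a dotted module path, so
this module can be run on its own from /vagrant with something like:

  bin/pytest --pyargs mjolnir.test.test_dbn
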
diff --git a/poc/data_process_dbn.py b/poc/data_process_dbn.py
deleted file mode 100644
index 4395d92..0000000
--- a/poc/data_process_dbn.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from pyspark import SparkContext
-from pyspark.sql import HiveContext
-
-import tempfile
-import json
-import codecs
-
-import clickmodels
-from clickmodels.inference import DbnModel
-from clickmodels.input_reader import InputReader, SessionItem
-
-import config
-from utils import spark_utils
-
-def train_dbn_partition():
-    # Extra wrapper is necessary to ensure we don't try and
-    # import config in the worker node
-    DBN_CONFIG = config.DBN_CONFIG
-    def work(iterator):
-        # Dump iterator into a temp file
-        reader = InputReader(DBN_CONFIG['MIN_DOCS_PER_QUERY'],
-                             DBN_CONFIG['MAX_DOCS_PER_QUERY'],
-                             False,
-                             DBN_CONFIG['SERP_SIZE'],
-                             False,
-                             discard_no_clicks=True)
-
-        # Evil hax to make our temporary file read/write utf-8,
-        # as the queries contain utf-8
-        f = tempfile.TemporaryFile()
-        info = codecs.lookup('utf-8')
-        f = codecs.StreamReaderWriter(f, info.streamreader, info.streamwriter, 'struct')
-        for row in iterator:
-            results = []
-            clicks = []
-            for hit in sorted(row.hits, key=lambda hit: hit.position):
-                results.append(str(hit.page_id))
-                clicks.append(bool(hit.clicked))
-            f.write('\t'.join([
-                "0", # unused identifier
-                row.norm_query,
-                "0", # region
-                "0", # intent weight
-                json.dumps(results), # displayed hits
-                json.dumps([False] * len(results)), # layout
-                json.dumps(clicks) # clicks
-            ]) + "\n")
-        f.seek(0)
-        sessions = reader(f)
-        del f
-
-
-        dbn_config = DBN_CONFIG.copy()
-        dbn_config['MAX_QUERY_ID'] = reader.current_query_id + 1
-        # Test with a single iteration 
-        #dbn_config['MAX_ITERATIONS'] = 1
-        model = DbnModel((0.9, 0.9, 0.9, 0.9), config=dbn_config)
-        model.train(sessions)
-
-        relevances = []
-        uid_to_url = dict((uid, url) for url, uid in reader.url_to_id.iteritems())
-        for (query, region), qid in reader.query_to_id.iteritems():
-            for uid, data in model.urlRelevances[False][qid].iteritems():
-                relevances.append([query, int(uid_to_url[uid]), data['a'] * data['s']])
-
-        return relevances
-    return work
-
-def session_to_dbn(row):
-    results = []
-    clicks = []
-    for hit in sorted(row.hits, key=lambda hit: hit.position):
-        results.append(str(hit.page_id))
-        clicks.append(bool(hit.clicked))
-
-    return [row.norm_query, results, clicks]
-
-def prep_dbn(hive):
-    hive.read.parquet(config.CLICK_DATA).registerTempTable('click_data')
-
-    hive.sql("""
-        SELECT
-            norm_query,
-            session_id,
-            hit_page_id,
-            AVG(hit_position) AS hit_position,
-            ARRAY_CONTAINS(COLLECT_LIST(clicked), true) as clicked
-        FROM
-            click_data
-        GROUP BY
-            norm_query,
-            session_id,
-            hit_page_id
-    """).registerTempTable('click_data_by_session')
-
-    return (hive.sql("""
-            SELECT
-                norm_query,
-                COLLECT_LIST(NAMED_STRUCT(
-                    'position', hit_position,
-                    'page_id', hit_page_id,
-                    'clicked', clicked
-                )) AS hits
-            FROM
-                click_data_by_session
-            GROUP BY
-                norm_query,
-                session_id
-        """)
-        # Sort guarantees all sessions for same query
-        # are in same partition
-        .sort('norm_query')
-    )
-
-
-def main():
-    sc, hive = spark_utils._init("LTR: DBN")
-
-    # Attach clickmodels .egg. Very bold assumption it's in an egg...
-    #clickmodels_path = clickmodels.__file__
-    #clickmodels_egg_path = clickmodels_path[:clickmodels_path.find('.egg')+4]
-    #sc.addPyFile(clickmodels_egg_path)
-
-    prep_dbn(hive) \
-        .mapPartitions(train_dbn_partition()) \
-        .toDF(['norm_query', 'hit_page_id', 'relevance']) \
-        .write.parquet(config.DBN_RELEVANCE)
-
-if __name__ == "__main__":
-    main()

-- 
To view, visit https://gerrit.wikimedia.org/r/347038
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: merged
Gerrit-Change-Id: Id0dd1cb267b90b0aaca096f471e5b35ebd88f156
Gerrit-PatchSet: 6
Gerrit-Project: search/MjoLniR
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <ebernhard...@wikimedia.org>
Gerrit-Reviewer: DCausse <dcau...@wikimedia.org>
Gerrit-Reviewer: EBernhardson <ebernhard...@wikimedia.org>
Gerrit-Reviewer: Gehel <guillaume.leder...@wikimedia.org>
Gerrit-Reviewer: Joal <j...@wikimedia.org>
Gerrit-Reviewer: Smalyshev <smalys...@wikimedia.org>
Gerrit-Reviewer: Tjones <tjo...@wikimedia.org>
Gerrit-Reviewer: Volans <rcocci...@wikimedia.org>
