EBernhardson has uploaded a new change for review. ( https://gerrit.wikimedia.org/r/350333 )

Change subject: [WIP] First stab at cross validation
......................................................................

[WIP] First stab at cross validation

Some code for splitting an input dataframe into pieces for
cross validation. Not complete, but might work.
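
A rough sketch of the intended usage for a simple train/test split
(the input dataframe df is hypothetical; split() is the function added
by this patch):

    import mjolnir.training.cross_validation as mjolnir_cv

    # df: one row per (wikiid, norm_query, hit) with a relevance label
    df_split = mjolnir_cv.split(df, [0.8, 0.2], output_column='fold')
    df_train = df_split.where(df_split.fold == 0)
    df_test = df_split.where(df_split.fold == 1)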

TODO:
* Lots of stuff
* Can this share code with mjolnir.sampling? The high-level concept
  is the same, and stratified sampling might be desirable here as well,
  though it may turn out to be unnecessary.

Change-Id: I7da9f3ec4a1354d6c7fb02eb98d2770740a8e4df
---
A mjolnir/training/cross_validation.py
1 file changed, 100 insertions(+), 0 deletions(-)


  git pull ssh://gerrit.wikimedia.org:29418/search/MjoLniR refs/changes/33/350333/1

diff --git a/mjolnir/training/cross_validation.py b/mjolnir/training/cross_validation.py
new file mode 100644
index 0000000..9bdb562
--- /dev/null
+++ b/mjolnir/training/cross_validation.py
@@ -0,0 +1,100 @@
+"""
+Support for making test/train or k-fold splits
+"""
+
+from collections import defaultdict
+import mjolnir.spark
+from pyspark.sql import functions as F
+
+
+def split(df, splits, output_column='fold', num_partitions=100):
+    """Assign splits to a dataframe of search results
+
+    Individual hits from the same normalized query are not independent;
+    they have large overlaps in result sets and relevance labels, so
+    splitting happens at the normalized query level.
+
+    Although splitting happens at the normalized query level, the split
+    percentages are still measured against the number of rows assigned to
+    each split, not the number of normalized queries. The split is also
+    balanced per wiki, meaning an 80/20 split will result in an 80/20
+    split for each wiki.
+
+    Parameters
+    ----------
+    df : pyspark.sql.DataFrame
+        Input dataframe containing (wikiid, norm_query) columns. If this is
+        expensive to compute, it should be cached, as it will be used twice.
+    splits : list
+        List of percentages, summing to 1, to split the input dataframe
+        into.
+    output_column : str, optional
+        Name of the new column the split index is written to.
+    num_partitions : int, optional
+        Number of partitions to perform the split with. Each partition
+        needs enough rows for the per-partition greedy assignment to
+        average out to the requested percentages. (Default: 100)
+
+    Returns
+    -------
+    pyspark.sql.DataFrame
+        Input dataframe with split indices assigned to a new column
+    """
+    # General sanity check on provided splits. We could attempt
+    # to normalize instead of fail, but this is good enough.
+    assert abs(1 - sum(splits)) < 0.01
+
+    mjolnir.spark.assert_columns(df, ['wikiid', 'norm_query'])
+
+    def split_partition(rows):
+        # Current number of items per split
+        split_counts = defaultdict(lambda: [0] * len(splits))
+        # Starting at 1 prevents division by zero. Using a float allows the
+        # later division to work as expected.
+        processed = defaultdict(lambda: 1.)
+        for row in rows:
+            # Assign the row to the first split that is currently below
+            # its desired percentage of rows for this wiki
+            for i, percent in enumerate(splits):
+                if split_counts[row.wikiid][i] / processed[row.wikiid] < percent:
+                    split_counts[row.wikiid][i] += row.weight
+                    yield (row.wikiid, row.norm_query, i)
+                    break
+            # If no split is below its target, assign to the first split
+            else:
+                split_counts[row.wikiid][0] += row.weight
+                yield (row.wikiid, row.norm_query, 0)
+            processed[row.wikiid] += row.weight
+
+    df_splits = (
+        df
+        .groupBy('wikiid', 'norm_query')
+        .agg(F.count(F.lit(1)).alias('weight'))
+        # Could we guess the correct number of partitions instead? I'm not
+        # sure how it should be decided though, and it would require taking
+        # an extra pass over the data.
+        .coalesce(num_partitions)
+        .rdd.mapPartitions(split_partition)
+        .toDF(['wikiid', 'norm_query', output_column]))
+
+    return df.join(df_splits, how='inner', on=['wikiid', 'norm_query'])
+
+
+def group_k_fold(df, num_folds, num_partitions=100, output_column='fold'):
+    """
+    Generates group k-fold splits. The fold a row belongs to is
+    assigned to the column identified by the output_column parameter.
+
+    Parameters
+    ----------
+    df : pyspark.sql.DataFrame
+    num_folds : int
+    test_folds : int, optional
+    vali_folds : int, optional
+    num_partitions : int, optional
+
+    Yields
+    ------
+    dict
+    """
+    return split(df, [1. / num_folds] * num_folds, output_column, num_partitions).cache()
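
As a usage sketch (not part of this patch; df is again hypothetical), the
fold column produced by group_k_fold() could drive a standard
cross-validation loop:

    from pyspark.sql import functions as F
    import mjolnir.training.cross_validation as mjolnir_cv

    num_folds = 5
    df_folds = mjolnir_cv.group_k_fold(df, num_folds)
    for fold in range(num_folds):
        df_train = df_folds.where(F.col('fold') != fold)
        df_test = df_folds.where(F.col('fold') == fold)
        # ... train a ranker on df_train and evaluate on df_test ...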

-- 
To view, visit https://gerrit.wikimedia.org/r/350333
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: newchange
Gerrit-Change-Id: I7da9f3ec4a1354d6c7fb02eb98d2770740a8e4df
Gerrit-PatchSet: 1
Gerrit-Project: search/MjoLniR
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <ebernhard...@wikimedia.org>
