Repository: spark
Updated Branches:
  refs/heads/master ccda75b0d -> 30fcdc038


[SPARK-22922][ML][PYSPARK] Pyspark portion of the fit-multiple API

## What changes were proposed in this pull request?

Adds a `fitMultiple` API to `Estimator` with a default implementation, and updates the
ml.tuning meta-estimators to use this API.
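
For illustration, a minimal usage sketch of the new API (the SparkSession, toy data, and
LogisticRegression below are illustrative, not part of this patch):

    from pyspark.sql import SparkSession
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.tuning import ParamGridBuilder

    spark = SparkSession.builder.getOrCreate()
    train = spark.createDataFrame(
        [(Vectors.dense([0.0]), 0.0), (Vectors.dense([1.0]), 1.0)],
        ["features", "label"])
    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 10]).build()
    # fitMultiple returns a thread safe iterator of (index, model) pairs;
    # models may arrive in any order, and index maps each one back into grid.
    for index, model in lr.fitMultiple(train, grid):
        print(index, model)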

## How was this patch tested?

Unit tests.

Author: Bago Amirbekian <b...@databricks.com>

Closes #20058 from MrBago/python-fitMultiple.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/30fcdc03
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/30fcdc03
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/30fcdc03

Branch: refs/heads/master
Commit: 30fcdc0380de4f107977d39d067b07e149ab2cb1
Parents: ccda75b
Author: Bago Amirbekian <b...@databricks.com>
Authored: Fri Dec 29 16:31:25 2017 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Fri Dec 29 16:31:25 2017 -0800

----------------------------------------------------------------------
 python/pyspark/ml/base.py   | 69 ++++++++++++++++++++++++++++++++++++++--
 python/pyspark/ml/tests.py  | 15 +++++++++
 python/pyspark/ml/tuning.py | 44 ++++++++++++++++---------
 3 files changed, 110 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/30fcdc03/python/pyspark/ml/base.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index a6767ce..d4470b5 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -18,13 +18,52 @@
 from abc import ABCMeta, abstractmethod
 
 import copy
+import threading
 
 from pyspark import since
-from pyspark.ml.param import Params
 from pyspark.ml.param.shared import *
 from pyspark.ml.common import inherit_doc
 from pyspark.sql.functions import udf
-from pyspark.sql.types import StructField, StructType, DoubleType
+from pyspark.sql.types import StructField, StructType
+
+
+class _FitMultipleIterator(object):
+    """
+    Used by default implementation of Estimator.fitMultiple to produce models in a thread safe
+    iterator. This class handles the simple case of fitMultiple where each param map should be
+    fit independently.
+
+    :param fitSingleModel: Function: (int => Model) which fits an estimator to a dataset.
+        `fitSingleModel` may be called up to `numModels` times, with a unique index each time.
+        Each call to `fitSingleModel` with an index should return the Model associated with
+        that index.
+    :param numModels: Number of models this iterator should produce.
+
+    See Estimator.fitMultiple for more info.
+    """
+    def __init__(self, fitSingleModel, numModels):
+        self.fitSingleModel = fitSingleModel
+        self.numModels = numModels
+        self.counter = 0
+        self.lock = threading.Lock()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        with self.lock:
+            index = self.counter
+            if index >= self.numModels:
+                raise StopIteration("No models remaining.")
+            self.counter += 1
+        return index, self.fitSingleModel(index)
+
+    def next(self):
+        """For python2 compatibility."""
+        return self.__next__()
 
 
 @inherit_doc
@@ -47,6 +86,27 @@ class Estimator(Params):
         """
         raise NotImplementedError()
 
+    @since("2.3.0")
+    def fitMultiple(self, dataset, paramMaps):
+        """
+        Fits a model to the input dataset for each param map in `paramMaps`.
+
+        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
+        :param paramMaps: A Sequence of param maps.
+        :return: A thread safe iterable which contains one model for each param map. Each
+                 call to `next(modelIterator)` will return `(index, model)` where model was fit
+                 using `paramMaps[index]`. `index` values may not be sequential.
+
+        .. note:: DeveloperApi
+        .. note:: Experimental
+        """
+        estimator = self.copy()
+
+        def fitSingleModel(index):
+            return estimator.fit(dataset, paramMaps[index])
+
+        return _FitMultipleIterator(fitSingleModel, len(paramMaps))
+
     @since("1.3.0")
     def fit(self, dataset, params=None):
         """
@@ -61,7 +121,10 @@ class Estimator(Params):
         if params is None:
             params = dict()
         if isinstance(params, (list, tuple)):
-            return [self.fit(dataset, paramMap) for paramMap in params]
+            models = [None] * len(params)
+            for index, model in self.fitMultiple(dataset, params):
+                models[index] = model
+            return models
         elif isinstance(params, dict):
             if params:
                 return self.copy(params)._fit(dataset)

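A sketch of how the iterator above behaves under concurrency, using a stand-in
fitSingleModel instead of a real fit (illustrative only; `_FitMultipleIterator` is the
private helper added above):

    from multiprocessing.pool import ThreadPool
    from pyspark.ml.base import _FitMultipleIterator

    def fitSingleModel(index):
        # Stand-in for a real fit: returns a string instead of a Model.
        return "model-%d" % index

    modelIter = _FitMultipleIterator(fitSingleModel, numModels=4)
    pool = ThreadPool(processes=2)
    # Two workers share one iterator; the internal lock guarantees each index
    # is handed out exactly once, though completion order is not guaranteed.
    results = pool.map(lambda _: next(modelIter), range(4))
    print(sorted(results))  # [(0, 'model-0'), (1, 'model-1'), ...]
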
http://git-wip-us.apache.org/repos/asf/spark/blob/30fcdc03/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index afcb088..1af2b91 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -2380,6 +2380,21 @@ class UnaryTransformerTests(SparkSessionTestCase):
             self.assertEqual(res.input + shiftVal, res.output)
 
 
+class EstimatorTest(unittest.TestCase):
+
+    def testDefaultFitMultiple(self):
+        N = 4
+        data = MockDataset()
+        estimator = MockEstimator()
+        params = [{estimator.fake: i} for i in range(N)]
+        modelIter = estimator.fitMultiple(data, params)
+        indexList = []
+        for index, model in modelIter:
+            self.assertEqual(model.getFake(), index)
+            indexList.append(index)
+        self.assertEqual(sorted(indexList), list(range(N)))
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests import *
     if xmlrunner:

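The default implementation tested above fits each param map independently; the point of
the API is that subclasses can override fitMultiple to share work across param maps. A
hypothetical override (the class names and caching strategy are illustrative, not part
of this patch):

    from pyspark.ml.base import Estimator, Model

    class IdentityModel(Model):
        def _transform(self, dataset):
            return dataset

    class CachedEstimator(Estimator):
        def fitMultiple(self, dataset, paramMaps):
            dataset.cache()  # one-time shared work, reused by every fit
            return super(CachedEstimator, self).fitMultiple(dataset, paramMaps)

        def _fit(self, dataset):
            return IdentityModel()
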
http://git-wip-us.apache.org/repos/asf/spark/blob/30fcdc03/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 4735113..6c0cad6 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -31,6 +31,28 @@ __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainVa
            'TrainValidationSplitModel']
 
 
+def _parallelFitTasks(est, train, eva, validation, epm):
+    """
+    Creates a list of callables which can be called from different threads to fit and evaluate
+    an estimator in parallel. Each callable returns an `(index, metric)` pair.
+
+    :param est: Estimator, the estimator to be fit.
+    :param train: DataFrame, training data set, used for fitting.
+    :param eva: Evaluator, used to compute `metric`.
+    :param validation: DataFrame, validation data set, used for evaluation.
+    :param epm: Sequence of ParamMap, param maps to be used during fitting & evaluation.
+    :return: List of callables; each call returns an `(index, metric)` pair, where `index`
+        points into `epm` and `metric` is the associated metric value.
+    """
+    modelIter = est.fitMultiple(train, epm)
+
+    def singleTask():
+        index, model = next(modelIter)
+        metric = eva.evaluate(model.transform(validation, epm[index]))
+        return index, metric
+
+    return [singleTask] * len(epm)
+
+
 class ParamGridBuilder(object):
     r"""
     Builder for a param grid used in grid search-based model selection.
@@ -266,15 +288,9 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW
             validation = df.filter(condition).cache()
             train = df.filter(~condition).cache()
 
-            def singleTrain(paramMap):
-                model = est.fit(train, paramMap)
-                # TODO: duplicate evaluator to take extra params from input
-                metric = eva.evaluate(model.transform(validation, paramMap))
-                return metric
-
-            currentFoldMetrics = pool.map(singleTrain, epm)
-            for j in range(numModels):
-                metrics[j] += (currentFoldMetrics[j] / nFolds)
+            tasks = _parallelFitTasks(est, train, eva, validation, epm)
+            for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+                metrics[j] += (metric / nFolds)
             validation.unpersist()
             train.unpersist()
 
@@ -523,13 +539,11 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl
         validation = df.filter(condition).cache()
         train = df.filter(~condition).cache()
 
-        def singleTrain(paramMap):
-            model = est.fit(train, paramMap)
-            metric = eva.evaluate(model.transform(validation, paramMap))
-            return metric
-
+        tasks = _parallelFitTasks(est, train, eva, validation, epm)
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
-        metrics = pool.map(singleTrain, epm)
+        metrics = [None] * numModels
+        for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+            metrics[j] = metric
         train.unpersist()
         validation.unpersist()
 
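The pattern above (identical callables draining one thread safe iterator via
pool.imap_unordered, as in `_parallelFitTasks`) can be seen in isolation with plain
Python, no Spark needed; the explicit lock here stands in for the one inside
fitMultiple's iterator:

    from multiprocessing.pool import ThreadPool
    import threading

    lock = threading.Lock()
    pairs = iter([(i, float(i) * 10) for i in range(4)])  # (index, metric) stand-ins

    def singleTask():
        with lock:  # fitMultiple's iterator does this locking internally
            return next(pairs)

    pool = ThreadPool(processes=2)
    metrics = [None] * 4
    # imap_unordered yields results as workers finish, in any order; the index
    # returned with each metric is what puts it back in the right slot.
    for j, metric in pool.imap_unordered(lambda f: f(), [singleTask] * 4):
        metrics[j] = metric
    print(metrics)  # [0.0, 10.0, 20.0, 30.0]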

