Repository: spark Updated Branches: refs/heads/master ccda75b0d -> 30fcdc038
[SPARK-22922][ML][PYSPARK] Pyspark portion of the fit-multiple API ## What changes were proposed in this pull request? Adding a fitMultiple API to `Estimator` with a default implementation. Also update the ml.tuning meta-estimators to use this API. ## How was this patch tested? Unit tests. Author: Bago Amirbekian <b...@databricks.com> Closes #20058 from MrBago/python-fitMultiple. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/30fcdc03 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/30fcdc03 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/30fcdc03 Branch: refs/heads/master Commit: 30fcdc0380de4f107977d39d067b07e149ab2cb1 Parents: ccda75b Author: Bago Amirbekian <b...@databricks.com> Authored: Fri Dec 29 16:31:25 2017 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Fri Dec 29 16:31:25 2017 -0800 ---------------------------------------------------------------------- python/pyspark/ml/base.py | 69 ++++++++++++++++++++++++++++++++++++++-- python/pyspark/ml/tests.py | 15 +++++++++ python/pyspark/ml/tuning.py | 44 ++++++++++++++++--------- 3 files changed, 110 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/30fcdc03/python/pyspark/ml/base.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py index a6767ce..d4470b5 100644 --- a/python/pyspark/ml/base.py +++ b/python/pyspark/ml/base.py @@ -18,13 +18,52 @@ from abc import ABCMeta, abstractmethod import copy +import threading from pyspark import since -from pyspark.ml.param import Params from pyspark.ml.param.shared import * from pyspark.ml.common import inherit_doc from pyspark.sql.functions import udf -from pyspark.sql.types import StructField, StructType, DoubleType +from pyspark.sql.types import StructField, StructType + + +class 
_FitMultipleIterator(object): + """ + Used by default implementation of Estimator.fitMultiple to produce models in a thread safe + iterator. This class handles the simple case of fitMultiple where each param map should be + fit independently. + + :param fitSingleModel: Function: (int => Model) which fits an estimator to a dataset. + `fitSingleModel` may be called up to `numModels` times, with a unique index each time. + Each call to `fitSingleModel` with an index should return the Model associated with + that index. + :param numModel: Number of models this iterator should produce. + + See Estimator.fitMultiple for more info. + """ + def __init__(self, fitSingleModel, numModels): + """ + + """ + self.fitSingleModel = fitSingleModel + self.numModel = numModels + self.counter = 0 + self.lock = threading.Lock() + + def __iter__(self): + return self + + def __next__(self): + with self.lock: + index = self.counter + if index >= self.numModel: + raise StopIteration("No models remaining.") + self.counter += 1 + return index, self.fitSingleModel(index) + + def next(self): + """For python2 compatibility.""" + return self.__next__() @inherit_doc @@ -47,6 +86,27 @@ class Estimator(Params): """ raise NotImplementedError() + @since("2.3.0") + def fitMultiple(self, dataset, paramMaps): + """ + Fits a model to the input dataset for each param map in `paramMaps`. + + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`. + :param paramMaps: A Sequence of param maps. + :return: A thread safe iterable which contains one model for each param map. Each + call to `next(modelIterator)` will return `(index, model)` where model was fit + using `paramMaps[index]`. `index` values may not be sequential. + + .. note:: DeveloperApi + .. 
note:: Experimental + """ + estimator = self.copy() + + def fitSingleModel(index): + return estimator.fit(dataset, paramMaps[index]) + + return _FitMultipleIterator(fitSingleModel, len(paramMaps)) + @since("1.3.0") def fit(self, dataset, params=None): """ @@ -61,7 +121,10 @@ class Estimator(Params): if params is None: params = dict() if isinstance(params, (list, tuple)): - return [self.fit(dataset, paramMap) for paramMap in params] + models = [None] * len(params) + for index, model in self.fitMultiple(dataset, params): + models[index] = model + return models elif isinstance(params, dict): if params: return self.copy(params)._fit(dataset) http://git-wip-us.apache.org/repos/asf/spark/blob/30fcdc03/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index afcb088..1af2b91 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2380,6 +2380,21 @@ class UnaryTransformerTests(SparkSessionTestCase): self.assertEqual(res.input + shiftVal, res.output) +class EstimatorTest(unittest.TestCase): + + def testDefaultFitMultiple(self): + N = 4 + data = MockDataset() + estimator = MockEstimator() + params = [{estimator.fake: i} for i in range(N)] + modelIter = estimator.fitMultiple(data, params) + indexList = [] + for index, model in modelIter: + self.assertEqual(model.getFake(), index) + indexList.append(index) + self.assertEqual(sorted(indexList), list(range(N))) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: http://git-wip-us.apache.org/repos/asf/spark/blob/30fcdc03/python/pyspark/ml/tuning.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 4735113..6c0cad6 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -31,6 +31,28 @@ __all__ = ['ParamGridBuilder', 'CrossValidator', 
'CrossValidatorModel', 'TrainVa 'TrainValidationSplitModel'] +def _parallelFitTasks(est, train, eva, validation, epm): + """ + Creates a list of callables which can be called from different threads to fit and evaluate + an estimator in parallel. Each callable returns an `(index, metric)` pair. + + :param est: Estimator, the estimator to be fit. + :param train: DataFrame, training data set, used for fitting. + :param eva: Evaluator, used to compute `metric` + :param validation: DataFrame, validation data set, used for evaluation. + :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation. + :return: (int, float), an index into `epm` and the associated metric value. + """ + modelIter = est.fitMultiple(train, epm) + + def singleTask(): + index, model = next(modelIter) + metric = eva.evaluate(model.transform(validation, epm[index])) + return index, metric + + return [singleTask] * len(epm) + + class ParamGridBuilder(object): r""" Builder for a param grid used in grid search-based model selection. 
@@ -266,15 +288,9 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW validation = df.filter(condition).cache() train = df.filter(~condition).cache() - def singleTrain(paramMap): - model = est.fit(train, paramMap) - # TODO: duplicate evaluator to take extra params from input - metric = eva.evaluate(model.transform(validation, paramMap)) - return metric - - currentFoldMetrics = pool.map(singleTrain, epm) - for j in range(numModels): - metrics[j] += (currentFoldMetrics[j] / nFolds) + tasks = _parallelFitTasks(est, train, eva, validation, epm) + for j, metric in pool.imap_unordered(lambda f: f(), tasks): + metrics[j] += (metric / nFolds) validation.unpersist() train.unpersist() @@ -523,13 +539,11 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl validation = df.filter(condition).cache() train = df.filter(~condition).cache() - def singleTrain(paramMap): - model = est.fit(train, paramMap) - metric = eva.evaluate(model.transform(validation, paramMap)) - return metric - + tasks = _parallelFitTasks(est, train, eva, validation, epm) pool = ThreadPool(processes=min(self.getParallelism(), numModels)) - metrics = pool.map(singleTrain, epm) + metrics = [None] * numModels + for j, metric in pool.imap_unordered(lambda f: f(), tasks): + metrics[j] = metric train.unpersist() validation.unpersist() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org