Github user MrBago commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20058#discussion_r159023958

--- Diff: python/pyspark/ml/base.py ---
@@ -47,6 +86,28 @@ def _fit(self, dataset):
         """
         raise NotImplementedError()

+    @since("2.3.0")
+    def fitMultiple(self, dataset, params):
+        """
+        Fits a model to the input dataset for each param map in params.
+
+        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
+        :param params: A Sequence of param maps.
+        :return: A thread safe iterable which contains one model for each param map. Each
+            call to `next(modelIterator)` will return `(index, model)` where model was fit
+            using `params[index]`. Params maps may be fit in an order different than their
+            order in params.
+
+        .. note:: DeveloperApi
+        .. note:: Experimental
+        """
+        estimator = self.copy()
+
+        def fitSingleModel(index):
+            return estimator.fit(dataset, params[index])
+
+        return FitMultipleIterator(fitSingleModel, len(params))
--- End diff --

The idea is that you should be able to do something like this:

```
pool = ...
modelIter = estimator.fitMultiple(dataset, params)
rng = range(len(params))
for index, model in pool.imap_unordered(lambda _: next(modelIter), rng):
    pass
```

That's pretty much how I've set up the cross validator to use it: https://github.com/apache/spark/pull/20058/files/fe3d6bddc3e9e50febf706d7f22007b1e0d58de3#diff-cbc8c36bfdd245e4e4d5bd27f9b95359R292

The reason for setting it up this way is so that, when appropriate, Estimators can implement their own optimized `fitMultiple` methods that just need to return an "iterator": a class with `__iter__` and `__next__`. For example, models that use `maxIter` and `maxDepth` params.
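To make the pattern concrete, here is a minimal sketch of the kind of thread-safe iterator the comment describes, plus the `imap_unordered` consumption loop. This is an illustration of the idea, not Spark's actual `FitMultipleIterator`; the `fitSingleModel` stand-in below is hypothetical (in the PR it would call `estimator.fit(dataset, params[index])`).

```python
import threading
from multiprocessing.pool import ThreadPool


class FitMultipleIterator(object):
    """Sketch of a thread-safe iterator whose next() yields (index, model).

    Only the index bookkeeping is guarded by the lock; the actual fitting
    happens outside it, so several threads can fit models concurrently.
    """

    def __init__(self, fitSingleModel, numModels):
        self.fitSingleModel = fitSingleModel
        self.numModels = numModels
        self.counter = 0
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            index = self.counter
            if index >= self.numModels:
                raise StopIteration()
            self.counter += 1
        # Fit outside the lock so workers run in parallel.
        return index, self.fitSingleModel(index)

    next = __next__  # Python 2 compatibility


def fitSingleModel(index):
    # Hypothetical stand-in for estimator.fit(dataset, params[index]).
    return "model-%d" % index


numModels = 3
modelIter = FitMultipleIterator(fitSingleModel, numModels)
models = [None] * numModels
pool = ThreadPool(2)
# Results may arrive in any order; the returned index tells us where each
# model belongs, which is exactly why next() yields (index, model) pairs.
for index, model in pool.imap_unordered(lambda _: next(modelIter), range(numModels)):
    models[index] = model
pool.close()
pool.join()
```

An Estimator with an optimized `fitMultiple` only needs to return any object with this `__iter__`/`__next__` contract; the consuming code (e.g. the cross validator) never needs to know whether the models were fit independently or shared work.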