Repository: spark Updated Branches: refs/heads/master 4816c2ef5 -> 9434280cf
[SPARK-20861][ML][PYTHON] Delegate looping over paramMaps to estimators Changes: pyspark.ml Estimators can take either a list of param maps or a dict of params. This change allows the CrossValidator and TrainValidationSplit Estimators to pass through lists of param maps to the underlying estimators so that those estimators can handle parallelization when appropriate (e.g., distributed hyperparameter tuning). Testing: Existing unit tests. Author: Bago Amirbekian <b...@databricks.com> Closes #18077 from MrBago/delegate_params. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9434280c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9434280c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9434280c Branch: refs/heads/master Commit: 9434280cfd1db94dc9d52bb0ace8283e710e3124 Parents: 4816c2e Author: Bago Amirbekian <b...@databricks.com> Authored: Tue May 23 20:56:01 2017 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Tue May 23 20:56:01 2017 -0700 ---------------------------------------------------------------------- python/pyspark/ml/tuning.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/9434280c/python/pyspark/ml/tuning.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index ffeb445..b648582 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,14 +18,11 @@ import itertools import numpy as np -from pyspark import SparkContext from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed -from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand -from pyspark.ml.common 
import inherit_doc, _py2java __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit', 'TrainValidationSplitModel'] @@ -232,8 +229,9 @@ class CrossValidator(Estimator, ValidatorParams): condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB) validation = df.filter(condition) train = df.filter(~condition) + models = est.fit(train, epm) for j in range(numModels): - model = est.fit(train, epm[j]) + model = models[j] # TODO: duplicate evaluator to take extra params from input metric = eva.evaluate(model.transform(validation, epm[j])) metrics[j] += metric/nFolds @@ -388,8 +386,9 @@ class TrainValidationSplit(Estimator, ValidatorParams): condition = (df[randCol] >= tRatio) validation = df.filter(condition) train = df.filter(~condition) + models = est.fit(train, epm) for j in range(numModels): - model = est.fit(train, epm[j]) + model = models[j] metric = eva.evaluate(model.transform(validation, epm[j])) metrics[j] += metric if eva.isLargerBetter(): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org