Repository: spark Updated Branches: refs/heads/master 9f019c722 -> 32cdc815c
[SPARK-6940] [MLLIB] Add CrossValidator to Python ML pipeline API Since CrossValidator is a meta algorithm, we copy the implementation in Python. jkbradley Author: Xiangrui Meng <m...@databricks.com> Closes #5926 from mengxr/SPARK-6940 and squashes the following commits: 6af181f [Xiangrui Meng] add TODOs 8285134 [Xiangrui Meng] update doc 060f7c3 [Xiangrui Meng] update doctest acac727 [Xiangrui Meng] add keyword args cdddecd [Xiangrui Meng] add CrossValidator in Python Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32cdc815 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32cdc815 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32cdc815 Branch: refs/heads/master Commit: 32cdc815c6fc19b5c8c4eca35f88a61302d67cd5 Parents: 9f019c7 Author: Xiangrui Meng <m...@databricks.com> Authored: Wed May 6 01:28:43 2015 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Wed May 6 01:28:43 2015 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/tuning/CrossValidator.scala | 7 +- python/pyspark/ml/pipeline.py | 13 +- python/pyspark/ml/tuning.py | 183 ++++++++++++++++++- python/pyspark/ml/wrapper.py | 4 +- 4 files changed, 199 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/32cdc815/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index cee2aa6..9208127 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -52,10 +52,12 @@ private[ml] trait CrossValidatorParams extends Params { def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps) /** - * param for the evaluator for selection + * param for the evaluator used to select hyper-parameters that maximize the cross-validated + * metric * @group param */ - val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") + val evaluator: Param[Evaluator] = new Param(this, "evaluator", + "evaluator used to select hyper-parameters that maximize the cross-validated metric") /** @group getParam */ def getEvaluator: Evaluator = $(evaluator) @@ -120,6 +122,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP trainingDataset.unpersist() var i = 0 while (i < numModels) { + // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) logDebug(s"Got metric $metric for model trained with ${epm(i)}.") metrics(i) += metric http://git-wip-us.apache.org/repos/asf/spark/blob/32cdc815/python/pyspark/ml/pipeline.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 7b875e4..c1b2077 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -22,7 +22,7 @@ from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc -__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator'] +__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator', 'Model'] @inherit_doc @@ -71,6 +71,15 @@ class Transformer(Params): @inherit_doc +class Model(Transformer): + """ + Abstract class for models that are fitted by estimators. + """ + + __metaclass__ = ABCMeta + + +@inherit_doc class Pipeline(Estimator): """ A simple pipeline, which acts as an estimator. A Pipeline consists @@ -154,7 +163,7 @@ class Pipeline(Estimator): @inherit_doc -class PipelineModel(Transformer): +class PipelineModel(Model): """ Represents a compiled pipeline with transformers and fitted models. """ http://git-wip-us.apache.org/repos/asf/spark/blob/32cdc815/python/pyspark/ml/tuning.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 1773ab5..f6cf2c3 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -16,8 +16,14 @@ # import itertools +import numpy as np -__all__ = ['ParamGridBuilder'] +from pyspark.ml.param import Params, Param +from pyspark.ml import Estimator, Model +from pyspark.ml.util import keyword_only +from pyspark.sql.functions import rand + +__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel'] class ParamGridBuilder(object): @@ -79,6 +85,179 @@ class ParamGridBuilder(object): return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] +class CrossValidator(Estimator): + """ + K-fold cross validation. + + >>> from pyspark.ml.classification import LogisticRegression + >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator + >>> from pyspark.mllib.linalg import Vectors + >>> dataset = sqlContext.createDataFrame( + ... [(Vectors.dense([0.0, 1.0]), 0.0), + ... (Vectors.dense([1.0, 2.0]), 1.0), + ... (Vectors.dense([0.55, 3.0]), 0.0), + ... (Vectors.dense([0.45, 4.0]), 1.0), + ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10, + ... ["features", "label"]) + >>> lr = LogisticRegression() + >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build() + >>> evaluator = BinaryClassificationEvaluator() + >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + >>> cvModel = cv.fit(dataset) + >>> expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset) + >>> cvModel.transform(dataset).collect() == expected.collect() + True + """ + + # a placeholder to make it appear in the generated doc + estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated") + + # a placeholder to make it appear in the generated doc + estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps") + + # a placeholder to make it appear in the generated doc + evaluator = Param( + Params._dummy(), "evaluator", + "evaluator used to select hyper-parameters that maximize the cross-validated metric") + + # a placeholder to make it appear in the generated doc + numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") + + @keyword_only + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + """ + __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3) + """ + super(CrossValidator, self).__init__() + #: param for estimator to be cross-validated + self.estimator = Param(self, "estimator", "estimator to be cross-validated") + #: param for estimator param maps + self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps") + #: param for the evaluator used to select hyper-parameters that + #: maximize the cross-validated metric + self.evaluator = Param( + self, "evaluator", + "evaluator used to select hyper-parameters that maximize the cross-validated metric") + #: param for number of folds for cross validation + self.numFolds = Param(self, "numFolds", "number of folds for cross validation") + self._setDefault(numFolds=3) + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + @keyword_only + def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + """ + setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + Sets params for cross validator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setEstimator(self, value): + """ + Sets the value of :py:attr:`estimator`. + """ + self.paramMap[self.estimator] = value + return self + + def getEstimator(self): + """ + Gets the value of estimator or its default value. + """ + return self.getOrDefault(self.estimator) + + def setEstimatorParamMaps(self, value): + """ + Sets the value of :py:attr:`estimatorParamMaps`. + """ + self.paramMap[self.estimatorParamMaps] = value + return self + + def getEstimatorParamMaps(self): + """ + Gets the value of estimatorParamMaps or its default value. + """ + return self.getOrDefault(self.estimatorParamMaps) + + def setEvaluator(self, value): + """ + Sets the value of :py:attr:`evaluator`. + """ + self.paramMap[self.evaluator] = value + return self + + def getEvaluator(self): + """ + Gets the value of evaluator or its default value. + """ + return self.getOrDefault(self.evaluator) + + def setNumFolds(self, value): + """ + Sets the value of :py:attr:`numFolds`. + """ + self.paramMap[self.numFolds] = value + return self + + def getNumFolds(self): + """ + Gets the value of numFolds or its default value. + """ + return self.getOrDefault(self.numFolds) + + def fit(self, dataset, params={}): + paramMap = self.extractParamMap(params) + est = paramMap[self.estimator] + epm = paramMap[self.estimatorParamMaps] + numModels = len(epm) + eva = paramMap[self.evaluator] + nFolds = paramMap[self.numFolds] + h = 1.0 / nFolds + randCol = self.uid + "_rand" + df = dataset.select("*", rand(0).alias(randCol)) + metrics = np.zeros(numModels) + for i in range(nFolds): + validateLB = i * h + validateUB = (i + 1) * h + condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB) + validation = df.filter(condition) + train = df.filter(~condition) + for j in range(numModels): + model = est.fit(train, epm[j]) + # TODO: duplicate evaluator to take extra params from input + metric = eva.evaluate(model.transform(validation, epm[j])) + metrics[j] += metric + bestIndex = np.argmax(metrics) + bestModel = est.fit(dataset, epm[bestIndex]) + return CrossValidatorModel(bestModel) + + +class CrossValidatorModel(Model): + """ + Model from k-fold cross validation. + """ + + def __init__(self, bestModel): + #: best model from cross validation + self.bestModel = bestModel + + def transform(self, dataset, params={}): + return self.bestModel.transform(dataset, params) + + if __name__ == "__main__": import doctest - doctest.testmod() + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.tuning tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) http://git-wip-us.apache.org/repos/asf/spark/blob/32cdc815/python/pyspark/ml/wrapper.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 73741c4..0634254 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -20,7 +20,7 @@ from abc import ABCMeta from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params -from pyspark.ml.pipeline import Estimator, Transformer, Evaluator +from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model from pyspark.mllib.common import inherit_doc @@ -133,7 +133,7 @@ class JavaTransformer(Transformer, JavaWrapper): @inherit_doc -class JavaModel(JavaTransformer): +class JavaModel(Model, JavaTransformer): """ Base class for :py:class:`Model`s that wrap Java/Scala implementations. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org