Repository: spark
Updated Branches:
  refs/heads/master 9657ee878 -> 3a44aebd0


[SPARK-9690][ML][PYTHON] pyspark CrossValidator random seed

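Extend CrossValidator with HasSeed in PySpark, so that the random column used
to assign rows to folds is generated with the configured seed instead of the
hard-coded rand(0).

This PR replaces https://github.com/apache/spark/pull/7997.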

CC: yanboliang thunterdb mmenestret. Would one of you mind taking a look?
Thanks!

Author: Joseph K. Bradley <jos...@databricks.com>
Author: Martin MENESTRET <mmenest...@ippon.fr>

Closes #10268 from jkbradley/pyspark-cv-seed.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3a44aebd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3a44aebd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3a44aebd

Branch: refs/heads/master
Commit: 3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb
Parents: 9657ee8
Author: Martin Menestret <martinmenest...@gmail.com>
Authored: Wed Dec 16 14:05:35 2015 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Dec 16 14:05:35 2015 -0800

----------------------------------------------------------------------
 python/pyspark/ml/tuning.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3a44aebd/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 705ee53..08f8db5 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -19,8 +19,9 @@ import itertools
 import numpy as np
 
 from pyspark import since
-from pyspark.ml.param import Params, Param
 from pyspark.ml import Estimator, Model
+from pyspark.ml.param import Params, Param
+from pyspark.ml.param.shared import HasSeed
 from pyspark.ml.util import keyword_only
 from pyspark.sql.functions import rand
 
@@ -89,7 +90,7 @@ class ParamGridBuilder(object):
         return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
 
 
-class CrossValidator(Estimator):
+class CrossValidator(Estimator, HasSeed):
     """
     K-fold cross validation.
 
@@ -129,9 +130,11 @@ class CrossValidator(Estimator):
     numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
 
     @keyword_only
-    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
+                 seed=None):
         """
-        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
+        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
+                 seed=None)
         """
         super(CrossValidator, self).__init__()
         #: param for estimator to be cross-validated
@@ -151,9 +154,11 @@ class CrossValidator(Estimator):
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
+                  seed=None):
         """
-        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
+                  seed=None):
         Sets params for cross validator.
         """
         kwargs = self.setParams._input_kwargs
@@ -225,9 +230,10 @@ class CrossValidator(Estimator):
         numModels = len(epm)
         eva = self.getOrDefault(self.evaluator)
         nFolds = self.getOrDefault(self.numFolds)
+        seed = self.getOrDefault(self.seed)
         h = 1.0 / nFolds
         randCol = self.uid + "_rand"
-        df = dataset.select("*", rand(0).alias(randCol))
+        df = dataset.select("*", rand(seed).alias(randCol))
         metrics = np.zeros(numModels)
         for i in range(nFolds):
             validateLB = i * h


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to