Repository: spark Updated Branches: refs/heads/master 7b6dc29d0 -> 860dc7f2f
[SPARK-9694][ML] Add random seed Param to Scala CrossValidator Add random seed Param to Scala CrossValidator Author: Yanbo Liang <yblia...@gmail.com> Closes #9108 from yanboliang/spark-9694. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/860dc7f2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/860dc7f2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/860dc7f2 Branch: refs/heads/master Commit: 860dc7f2f8dd01f2562ba83b7af27ba29d91cb62 Parents: 7b6dc29 Author: Yanbo Liang <yblia...@gmail.com> Authored: Wed Dec 16 11:05:37 2015 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Wed Dec 16 11:05:37 2015 -0800 ---------------------------------------------------------------------- .../org/apache/spark/ml/tuning/CrossValidator.scala | 11 ++++++++--- .../main/scala/org/apache/spark/mllib/util/MLUtils.scala | 8 ++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/860dc7f2/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5c09f1a..40f8857 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -29,8 +29,9 @@ import org.apache.spark.ml.classification.OneVsRestParams import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -39,7 +40,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ -private[ml] trait CrossValidatorParams extends ValidatorParams { +private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed { /** * Param for number of folds for cross validation. Must be >= 2. * Default: 3 @@ -85,6 +86,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("1.2.0") def setNumFolds(value: Int): this.type = set(numFolds, value) + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + @Since("1.4.0") override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema @@ -95,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed)) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() http://git-wip-us.apache.org/repos/asf/spark/blob/860dc7f2/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 414ea99..4c9151f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -265,6 +265,14 @@ object MLUtils { */ @Since("1.0.0") def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { + kFold(rdd, numFolds, seed.toLong) + } + + /** + * Version of [[kFold()]] taking a Long seed. + */ + @Since("2.0.0") + def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Long): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat (1 to numFolds).map { fold => val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org