Repository: spark Updated Branches: refs/heads/master 26c9d7a0f -> d8662cd90
[SPARK-6164] [ML] CrossValidatorModel should keep stats from fitting Added stats from cross validation as a val in the cross validation model to save them for user access. Author: leahmcguire <lmcgu...@salesforce.com> Closes #5915 from leahmcguire/saveCVmetrics and squashes the following commits: 49b507b [leahmcguire] fixed style error 67537b1 [leahmcguire] rebased 85907f0 [leahmcguire] fixed name 59987cc [leahmcguire] changed param name and test according to comments 36e71e3 [leahmcguire] rebasing 4b8223e [leahmcguire] fixed name 4ddffc6 [leahmcguire] changed param name and test according to comments 3a995da [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d8662cd9 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d8662cd9 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d8662cd9 Branch: refs/heads/master Commit: d8662cd909a41575df6e0ea1630d2386d3711240 Parents: 26c9d7a Author: leahmcguire <lmcgu...@salesforce.com> Authored: Wed Jun 3 15:46:38 2015 -0700 Committer: Joseph K. 
Bradley <jos...@databricks.com> Committed: Wed Jun 3 15:46:38 2015 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 10 +++++++--- .../org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d8662cd9/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 6434b64..cb29392 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - copyValues(new CrossValidatorModel(uid, bestModel).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM @Experimental class CrossValidatorModel private[ml] ( override val uid: String, - val bestModel: Model[_]) + val bestModel: Model[_], + val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams { override def validateParams(): Unit = { @@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] ( } override def copy(extra: ParamMap): CrossValidatorModel = { - val copied = new CrossValidatorModel(uid, 
bestModel.copy(extra).asInstanceOf[Model[_]]) + val copied = new CrossValidatorModel( + uid, + bestModel.copy(extra).asInstanceOf[Model[_]], + avgMetrics.clone()) copyValues(copied, extra) } } http://git-wip-us.apache.org/repos/asf/spark/blob/d8662cd9/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 5ba469c..9b3619f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) + assert(cvModel.avgMetrics.length === lrParamMaps.length) } test("validateParams should check estimatorParamMaps") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org