Repository: spark
Updated Branches:
  refs/heads/master 26c9d7a0f -> d8662cd90


[SPARK-6164] [ML] CrossValidatorModel should keep stats from fitting

Store the average cross-validation metric for each parameter map as a new
`avgMetrics: Array[Double]` field on `CrossValidatorModel`, so users can inspect it after fitting.

Author: leahmcguire <lmcgu...@salesforce.com>

Closes #5915 from leahmcguire/saveCVmetrics and squashes the following commits:

49b507b [leahmcguire] fixed style error
67537b1 [leahmcguire] rebased
85907f0 [leahmcguire] fixed name
59987cc [leahmcguire] changed param name and test according to comments
36e71e3 [leahmcguire] rebasing
4b8223e [leahmcguire] fixed name
4ddffc6 [leahmcguire] changed param name and test according to comments
3a995da [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d8662cd9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d8662cd9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d8662cd9

Branch: refs/heads/master
Commit: d8662cd909a41575df6e0ea1630d2386d3711240
Parents: 26c9d7a
Author: leahmcguire <lmcgu...@salesforce.com>
Authored: Wed Jun 3 15:46:38 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Jun 3 15:46:38 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 10 +++++++---
 .../org/apache/spark/ml/tuning/CrossValidatorSuite.scala  |  1 +
 2 files changed, 8 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d8662cd9/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 6434b64..cb29392 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
     logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
     logInfo(s"Best cross-validation metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
-    copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
+    copyValues(new CrossValidatorModel(uid, bestModel, 
metrics).setParent(this))
   }
 
   override def transformSchema(schema: StructType): StructType = {
@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
 @Experimental
 class CrossValidatorModel private[ml] (
     override val uid: String,
-    val bestModel: Model[_])
+    val bestModel: Model[_],
+    val avgMetrics: Array[Double])
   extends Model[CrossValidatorModel] with CrossValidatorParams {
 
   override def validateParams(): Unit = {
@@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] (
   }
 
   override def copy(extra: ParamMap): CrossValidatorModel = {
-    val copied = new CrossValidatorModel(uid, 
bestModel.copy(extra).asInstanceOf[Model[_]])
+    val copied = new CrossValidatorModel(
+      uid,
+      bestModel.copy(extra).asInstanceOf[Model[_]],
+      avgMetrics.clone())
     copyValues(copied, extra)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d8662cd9/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 5ba469c..9b3619f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
     val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
     assert(parent.getRegParam === 0.001)
     assert(parent.getMaxIter === 10)
+    assert(cvModel.avgMetrics.length === lrParamMaps.length)
   }
 
   test("validateParams should check estimatorParamMaps") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to