Repository: spark
Updated Branches:
  refs/heads/master 57ec27dd7 -> 70fe55886


[SPARK-9847] [ML] Modified copyValues to distinguish between default, explicit 
param values

From JIRA: Currently, Params.copyValues copies default parameter values to the
paramMap of the target instance, rather than the defaultParamMap. It should
copy to the defaultParamMap because explicitly setting a parameter can change
the semantics.
This issue arose in SPARK-9789, where 2 params "threshold" and "thresholds" for 
LogisticRegression can have mutually exclusive values. If thresholds is set, 
then fit() will copy the default value of threshold as well, easily resulting 
in inconsistent settings for the 2 params.

CC: mengxr

Author: Joseph K. Bradley <jos...@databricks.com>

Closes #8115 from jkbradley/copyvalues-fix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/70fe5588
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/70fe5588
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/70fe5588

Branch: refs/heads/master
Commit: 70fe558867ccb4bcff6ec673438b03608bb02252
Parents: 57ec27d
Author: Joseph K. Bradley <jos...@databricks.com>
Authored: Wed Aug 12 10:48:52 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed Aug 12 10:48:52 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/param/params.scala | 19 ++++++++++++++++---
 .../org/apache/spark/ml/param/ParamsSuite.scala  |  8 ++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/70fe5588/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala 
b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index d68f5ff..91c0a56 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -559,13 +559,26 @@ trait Params extends Identifiable with Serializable {
 
   /**
    * Copies param values from this instance to another instance for params 
shared by them.
-   * @param to the target instance
-   * @param extra extra params to be copied
+   *
+   * This handles default Params and explicitly set Params separately.
+   * Default Params are copied from and to [[defaultParamMap]], and explicitly 
set Params are
+   * copied from and to [[paramMap]].
+   * Warning: This implicitly assumes that this [[Params]] instance and the 
target instance
+   *          share the same set of default Params.
+   *
+   * @param to the target instance, which should work with the same set of 
default Params as this
+   *           source instance
+   * @param extra extra params to be copied to the target's [[paramMap]]
    * @return the target instance with param values copied
    */
   protected def copyValues[T <: Params](to: T, extra: ParamMap = 
ParamMap.empty): T = {
-    val map = extractParamMap(extra)
+    val map = paramMap ++ extra
     params.foreach { param =>
+      // copy default Params
+      if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
+        to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
+      }
+      // copy explicitly set Params
       if (map.contains(param) && to.hasParam(param.name)) {
         to.set(param.name, map(param))
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/70fe5588/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 050d417..be95638 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -200,6 +200,14 @@ class ParamsSuite extends SparkFunSuite {
     val inArray = ParamValidators.inArray[Int](Array(1, 2))
     assert(inArray(1) && inArray(2) && !inArray(0))
   }
+
+  test("Params.copyValues") {
+    val t = new TestParams()
+    val t2 = t.copy(ParamMap.empty)
+    assert(!t2.isSet(t2.maxIter))
+    val t3 = t.copy(ParamMap(t.maxIter -> 20))
+    assert(t3.isSet(t3.maxIter))
+  }
 }
 
 object ParamsSuite extends SparkFunSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to