Repository: spark
Updated Branches:
  refs/heads/master 93dbfe705 -> 2a903a1ee


[SPARK-19985][ML] Fixed copy method for some ML Models

## What changes were proposed in this pull request?
Some ML Models were using `defaultCopy`, which expects a constructor that takes
only a `uid`, and others were not setting the parent estimator. This change fixes
both problems by creating a new instance of the model and explicitly setting its
values and parent.

## How was this patch tested?
Added `MLTestingUtils.checkCopy` calls to the tests of the offending models to 
verify that the copy is made and the parent is set.

Author: Bryan Cutler <cutl...@gmail.com>

Closes #17326 from BryanCutler/ml-model-copy-error-SPARK-19985.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a903a1e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a903a1e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a903a1e

Branch: refs/heads/master
Commit: 2a903a1eec46e3bd58af0fcbc57e76752d9c18b3
Parents: 93dbfe7
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Mon Apr 3 10:56:54 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Mon Apr 3 10:56:54 2017 +0200

----------------------------------------------------------------------
 .../classification/MultilayerPerceptronClassifier.scala  |  3 ++-
 .../spark/ml/feature/BucketedRandomProjectionLSH.scala   |  5 ++++-
 .../scala/org/apache/spark/ml/feature/MinHashLSH.scala   |  5 ++++-
 .../scala/org/apache/spark/ml/feature/RFormula.scala     |  6 ++++--
 .../MultilayerPerceptronClassifierSuite.scala            |  1 +
 .../ml/feature/BucketedRandomProjectionLSHSuite.scala    |  6 ++++--
 .../org/apache/spark/ml/feature/MinHashLSHSuite.scala    | 11 ++++++++++-
 .../org/apache/spark/ml/feature/RFormulaSuite.scala      |  1 +
 8 files changed, 30 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 95c1337..ec39f96 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -329,7 +329,8 @@ class MultilayerPerceptronClassificationModel private[ml] (
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel 
= {
-    copyValues(new MultilayerPerceptronClassificationModel(uid, layers, 
weights), extra)
+    val copied = new MultilayerPerceptronClassificationModel(uid, layers, 
weights).setParent(parent)
+    copyValues(copied, extra)
   }
 
   @Since("2.0.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index cbac163..36a46ca 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -96,7 +96,10 @@ class BucketedRandomProjectionLSHModel private[ml](
   }
 
   @Since("2.1.0")
-  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+  override def copy(extra: ParamMap): BucketedRandomProjectionLSHModel = {
+    val copied = new BucketedRandomProjectionLSHModel(uid, 
randUnitVectors).setParent(parent)
+    copyValues(copied, extra)
+  }
 
   @Since("2.1.0")
   override def write: MLWriter = {

http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index 620e1fb..145422a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -86,7 +86,10 @@ class MinHashLSHModel private[ml](
   }
 
   @Since("2.1.0")
-  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+  override def copy(extra: ParamMap): MinHashLSHModel = {
+    val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent)
+    copyValues(copied, extra)
+  }
 
   @Since("2.1.0")
   override def write: MLWriter = new 
MinHashLSHModel.MinHashLSHModelWriter(this)

http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 3898986..5a3e292 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -268,8 +268,10 @@ class RFormulaModel private[feature](
   }
 
   @Since("1.5.0")
-  override def copy(extra: ParamMap): RFormulaModel = copyValues(
-    new RFormulaModel(uid, resolvedFormula, pipelineModel))
+  override def copy(extra: ParamMap): RFormulaModel = {
+    val copied = new RFormulaModel(uid, resolvedFormula, 
pipelineModel).setParent(parent)
+    copyValues(copied, extra)
+  }
 
   @Since("2.0.0")
   override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"

http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 41684d9..7700099 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -74,6 +74,7 @@ class MultilayerPerceptronClassifierSuite
       .setMaxIter(100)
       .setSolver("l-bfgs")
     val model = trainer.fit(dataset)
+    MLTestingUtils.checkCopy(model)
     val result = model.transform(dataset)
     val predictionAndLabels = result.select("prediction", "label").collect()
     predictionAndLabels.foreach { case Row(p: Double, l: Double) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index 91eac9e..cc81da5 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -23,7 +23,7 @@ import breeze.numerics.constants.Pi
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Dataset
@@ -89,10 +89,12 @@ class BucketedRandomProjectionLSHSuite
       .setOutputCol("values")
       .setBucketLength(1.0)
       .setSeed(12345)
-    val unitVectors = brp.fit(dataset).randUnitVectors
+    val brpModel = brp.fit(dataset)
+    val unitVectors = brpModel.randUnitVectors
     unitVectors.foreach { v: Vector =>
       assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
     }
+    MLTestingUtils.checkCopy(brpModel)
   }
 
   test("BucketedRandomProjectionLSH: test of LSH property") {

http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index a2f0093..0ddf097 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Dataset
 
@@ -57,6 +57,15 @@ class MinHashLSHSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defa
     testEstimatorAndModelReadWrite(mh, dataset, settings, settings, 
checkModelData)
   }
 
+  test("Model copy and uid checks") {
+    val mh = new MinHashLSH()
+      .setInputCol("keys")
+      .setOutputCol("values")
+    val model = mh.fit(dataset)
+    assert(mh.uid === model.uid)
+    MLTestingUtils.checkCopy(model)
+  }
+
   test("hashFunction") {
     val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 
2), (3, 0)))
     val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), 
(5, 1.0), (7, 1.0))))

http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index c664460..5cfd59e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -37,6 +37,7 @@ class RFormulaSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
     val formula = new RFormula().setFormula("id ~ v1 + v2")
     val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
     val model = formula.fit(original)
+    MLTestingUtils.checkCopy(model)
     val result = model.transform(original)
     val resultSchema = model.transformSchema(original.schema)
     val expected = Seq(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to