Repository: spark Updated Branches: refs/heads/master 93dbfe705 -> 2a903a1ee
[SPARK-19985][ML] Fixed copy method for some ML Models ## What changes were proposed in this pull request? Some ML Models were using `defaultCopy`, which expects a default constructor, and others were not setting the parent estimator. This change fixes these by creating a new instance of the model and explicitly setting values and parent. ## How was this patch tested? Added `MLTestingUtils.checkCopy` calls to the tests of the offending models to verify that the copy is made and the parent is set. Author: Bryan Cutler <cutl...@gmail.com> Closes #17326 from BryanCutler/ml-model-copy-error-SPARK-19985. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a903a1e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a903a1e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a903a1e Branch: refs/heads/master Commit: 2a903a1eec46e3bd58af0fcbc57e76752d9c18b3 Parents: 93dbfe7 Author: Bryan Cutler <cutl...@gmail.com> Authored: Mon Apr 3 10:56:54 2017 +0200 Committer: Nick Pentreath <ni...@za.ibm.com> Committed: Mon Apr 3 10:56:54 2017 +0200 ---------------------------------------------------------------------- .../classification/MultilayerPerceptronClassifier.scala | 3 ++- .../spark/ml/feature/BucketedRandomProjectionLSH.scala | 5 ++++- .../scala/org/apache/spark/ml/feature/MinHashLSH.scala | 5 ++++- .../scala/org/apache/spark/ml/feature/RFormula.scala | 6 ++++-- .../MultilayerPerceptronClassifierSuite.scala | 1 + .../ml/feature/BucketedRandomProjectionLSHSuite.scala | 6 ++++-- .../org/apache/spark/ml/feature/MinHashLSHSuite.scala | 11 ++++++++++- .../org/apache/spark/ml/feature/RFormulaSuite.scala | 1 + 8 files changed, 30 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala 
---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 95c1337..ec39f96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -329,7 +329,8 @@ class MultilayerPerceptronClassificationModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { - copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) + val copied = new MultilayerPerceptronClassificationModel(uid, layers, weights).setParent(parent) + copyValues(copied, extra) } @Since("2.0.0") http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index cbac163..36a46ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -96,7 +96,10 @@ class BucketedRandomProjectionLSHModel private[ml]( } @Since("2.1.0") - override def copy(extra: ParamMap): this.type = defaultCopy(extra) + override def copy(extra: ParamMap): BucketedRandomProjectionLSHModel = { + val copied = new BucketedRandomProjectionLSHModel(uid, randUnitVectors).setParent(parent) + copyValues(copied, extra) + } @Since("2.1.0") override def write: MLWriter = { 
http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 620e1fb..145422a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -86,7 +86,10 @@ class MinHashLSHModel private[ml]( } @Since("2.1.0") - override def copy(extra: ParamMap): this.type = defaultCopy(extra) + override def copy(extra: ParamMap): MinHashLSHModel = { + val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent) + copyValues(copied, extra) + } @Since("2.1.0") override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this) http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 3898986..5a3e292 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -268,8 +268,10 @@ class RFormulaModel private[feature]( } @Since("1.5.0") - override def copy(extra: ParamMap): RFormulaModel = copyValues( - new RFormulaModel(uid, resolvedFormula, pipelineModel)) + override def copy(extra: ParamMap): RFormulaModel = { + val copied = new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(parent) + copyValues(copied, extra) + } @Since("2.0.0") override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" 
http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 41684d9..7700099 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -74,6 +74,7 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(dataset) + MLTestingUtils.checkCopy(model) val result = model.transform(dataset) val predictionAndLabels = result.select("prediction", "label").collect() predictionAndLabels.foreach { case Row(p: Double, l: Double) => http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index 91eac9e..cc81da5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -23,7 +23,7 @@ import breeze.numerics.constants.Pi import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import 
org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -89,10 +89,12 @@ class BucketedRandomProjectionLSHSuite .setOutputCol("values") .setBucketLength(1.0) .setSeed(12345) - val unitVectors = brp.fit(dataset).randUnitVectors + val brpModel = brp.fit(dataset) + val unitVectors = brpModel.randUnitVectors unitVectors.foreach { v: Vector => assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14) } + MLTestingUtils.checkCopy(brpModel) } test("BucketedRandomProjectionLSH: test of LSH property") { http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index a2f0093..0ddf097 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -57,6 +57,15 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData) } + test("Model copy and uid checks") { + val mh = new MinHashLSH() + .setInputCol("keys") + .setOutputCol("values") + val model = mh.fit(dataset) + assert(mh.uid === model.uid) + MLTestingUtils.checkCopy(model) + } + test("hashFunction") { val model = new MinHashLSHModel("mh", randCoefficients = 
Array((0, 1), (1, 2), (3, 0))) val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))) http://git-wip-us.apache.org/repos/asf/spark/blob/2a903a1e/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index c664460..5cfd59e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -37,6 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val formula = new RFormula().setFormula("id ~ v1 + v2") val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) + MLTestingUtils.checkCopy(model) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) val expected = Seq( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org