Repository: spark
Updated Branches:
  refs/heads/master 15d392688 -> 23ce0d1e9


[SPARK-18276][ML] ML models should copy the training summary and set parent

## What changes were proposed in this pull request?

Only some of the models which contain a training summary currently set the 
summaries in the copy method. Linear/Logistic regression do; GLR, GMM, KM, and 
BKM do not. Additionally, these copy methods did not set the parent pointer of 
the copied model. This patch modifies the copy methods of the four models 
mentioned above to copy the training summary and set the parent.

## How was this patch tested?

Add unit tests in Linear/Logistic/GeneralizedLinear regression and 
GaussianMixture/KMeans/BisectingKMeans to check the parent pointer of the 
copied model and check that the copied model has a summary.

Author: sethah <seth.hendrickso...@gmail.com>

Closes #15773 from sethah/SPARK-18276.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/23ce0d1e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/23ce0d1e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/23ce0d1e

Branch: refs/heads/master
Commit: 23ce0d1e91076d90c1a87d698a94d283d08cf899
Parents: 15d3926
Author: sethah <seth.hendrickso...@gmail.com>
Authored: Sat Nov 5 22:38:07 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Sat Nov 5 22:38:07 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/BisectingKMeans.scala |  5 +++--
 .../org/apache/spark/ml/clustering/GaussianMixture.scala |  5 +++--
 .../scala/org/apache/spark/ml/clustering/KMeans.scala    |  5 +++--
 .../ml/regression/GeneralizedLinearRegression.scala      |  6 ++++--
 .../apache/spark/ml/tuning/TrainValidationSplit.scala    |  2 +-
 .../ml/classification/LogisticRegressionSuite.scala      | 11 +++++++----
 .../spark/ml/clustering/BisectingKMeansSuite.scala       | 10 +++++++++-
 .../spark/ml/clustering/GaussianMixtureSuite.scala       | 10 +++++++++-
 .../org/apache/spark/ml/clustering/KMeansSuite.scala     | 10 +++++++++-
 .../ml/regression/GeneralizedLinearRegressionSuite.scala |  5 ++++-
 .../spark/ml/regression/LinearRegressionSuite.scala      |  5 ++++-
 .../spark/ml/tuning/TrainValidationSplitSuite.scala      |  8 ++++++--
 12 files changed, 62 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 2718dd9..f8a606d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -94,8 +94,9 @@ class BisectingKMeansModel private[ml] (
 
   @Since("2.0.0")
   override def copy(extra: ParamMap): BisectingKMeansModel = {
-    val copied = new BisectingKMeansModel(uid, parentModel)
-    copyValues(copied, extra)
+    val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
+    if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+    copied.setParent(this.parent)
   }
 
   @Since("2.0.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 8fac63f..a0bd66e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -89,8 +89,9 @@ class GaussianMixtureModel private[ml] (
 
   @Since("2.0.0")
   override def copy(extra: ParamMap): GaussianMixtureModel = {
-    val copied = new GaussianMixtureModel(uid, weights, gaussians)
-    copyValues(copied, extra).setParent(this.parent)
+    val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), 
extra)
+    if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+    copied.setParent(this.parent)
   }
 
   @Since("2.0.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 85bb8c9..a0d481b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -108,8 +108,9 @@ class KMeansModel private[ml] (
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeansModel = {
-    val copied = new KMeansModel(uid, parentModel)
-    copyValues(copied, extra)
+    val copied = copyValues(new KMeansModel(uid, parentModel), extra)
+    if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+    copied.setParent(this.parent)
   }
 
   /** @group setParam */

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 8656ecf..1938e8e 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -776,8 +776,10 @@ class GeneralizedLinearRegressionModel private[ml] (
 
   @Since("2.0.0")
   override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
-    copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, 
intercept), extra)
-      .setParent(parent)
+    val copied = copyValues(new GeneralizedLinearRegressionModel(uid, 
coefficients, intercept),
+      extra)
+    if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+    copied.setParent(parent)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 0fdba1c..5d1a39f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -221,7 +221,7 @@ class TrainValidationSplitModel private[ml] (
       uid,
       bestModel.copy(extra).asInstanceOf[Model[_]],
       validationMetrics.clone())
-    copyValues(copied, extra)
+    copyValues(copied, extra).setParent(parent)
   }
 
   @Since("2.0.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 8771fd2..2877285 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.classification.LogisticRegressionSuite._
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, 
SparseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -141,6 +141,12 @@ class LogisticRegressionSuite
     assert(model.getProbabilityCol === "probability")
     assert(model.intercept !== 0.0)
     assert(model.hasParent)
+
+    // copied model must have the same parent.
+    MLTestingUtils.checkCopy(model)
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
   }
 
   test("empty probabilityCol") {
@@ -251,9 +257,6 @@ class LogisticRegressionSuite
     mlr.setFitIntercept(false)
     val mlrModel = mlr.fit(smallMultinomialDataset)
     assert(mlrModel.interceptVector === Vectors.sparse(3, Seq()))
-
-    // copied model must have the same parent.
-    MLTestingUtils.checkCopy(model)
   }
 
   test("logistic regression with setters") {

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index f2368a9..49797d9 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Dataset
 
@@ -41,6 +42,13 @@ class BisectingKMeansSuite
     assert(bkm.getPredictionCol === "prediction")
     assert(bkm.getMaxIter === 20)
     assert(bkm.getMinDivisibleClusterSize === 1.0)
+    val model = bkm.setMaxIter(1).fit(dataset)
+
+    // copied model must have the same parent
+    MLTestingUtils.checkCopy(model)
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
   }
 
   test("setter/getter") {

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 003fa6a..7165b63 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Dataset
 
@@ -43,6 +44,13 @@ class GaussianMixtureSuite extends SparkFunSuite with 
MLlibTestSparkContext
     assert(gm.getPredictionCol === "prediction")
     assert(gm.getMaxIter === 100)
     assert(gm.getTol === 0.01)
+    val model = gm.setMaxIter(1).fit(dataset)
+
+    // copied model must have the same parent
+    MLTestingUtils.checkCopy(model)
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
   }
 
   test("set parameters") {

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index ca39265..7397255 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
@@ -47,6 +48,13 @@ class KMeansSuite extends SparkFunSuite with 
MLlibTestSparkContext with DefaultR
     assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
     assert(kmeans.getInitSteps === 2)
     assert(kmeans.getTol === 1e-4)
+    val model = kmeans.setMaxIter(1).fit(dataset)
+
+    // copied model must have the same parent
+    MLTestingUtils.checkCopy(model)
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
   }
 
   test("set parameters") {

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index ac1ef5f..111bc97 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -24,7 +24,7 @@ import 
org.apache.spark.ml.classification.LogisticRegressionSuite._
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random._
@@ -183,6 +183,9 @@ class GeneralizedLinearRegressionSuite
 
     // copied model must have the same parent.
     MLTestingUtils.checkCopy(model)
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
 
     assert(model.getFeaturesCol === "features")
     assert(model.getPredictionCol === "prediction")

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index c0e8afb..df97d0b 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
@@ -143,6 +143,9 @@ class LinearRegressionSuite
 
     // copied model must have the same parent.
     MLTestingUtils.checkCopy(model)
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
 
     model.transform(datasetWithDenseFeature)
       .select("label", "prediction")

http://git-wip-us.apache.org/repos/asf/spark/blob/23ce0d1e/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 87100ae..4463a9b 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -22,11 +22,11 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.classification.{LogisticRegression, 
LogisticRegressionModel}
 import 
org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, 
Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
+import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
@@ -78,6 +78,10 @@ class TrainValidationSplitSuite
       .setTrainRatio(0.5)
       .setSeed(42L)
     val cvModel = cv.fit(dataset)
+
+    // copied model must have the same parent.
+    MLTestingUtils.checkCopy(cvModel)
+
     val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
     assert(parent.getRegParam === 0.001)
     assert(parent.getMaxIter === 10)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to