Repository: spark
Updated Branches:
  refs/heads/branch-1.4 8567d29ef -> 24cb323e7
[SPARK-7047] [ML] ml.Model optional parent support Made Model.parent transient. Added Model.hasParent to test for null parent CC: mengxr Author: Joseph K. Bradley <jos...@databricks.com> Closes #5914 from jkbradley/parent-optional and squashes the following commits: d501774 [Joseph K. Bradley] Made Model.parent transient. Added Model.hasParent to test for null parent (cherry picked from commit fb90273212dc7241c9a0c3446e25e0e0b9377750) Signed-off-by: Xiangrui Meng <m...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/24cb323e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/24cb323e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/24cb323e Branch: refs/heads/branch-1.4 Commit: 24cb323e767a342496cf24e0d06398b5af38ac80 Parents: 8567d29 Author: Joseph K. Bradley <jos...@databricks.com> Authored: Tue May 19 10:55:21 2015 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Tue May 19 10:55:32 2015 -0700 ---------------------------------------------------------------------- mllib/src/main/scala/org/apache/spark/ml/Model.scala | 5 ++++- .../spark/ml/classification/LogisticRegressionSuite.scala | 1 + .../spark/ml/classification/RandomForestClassifierSuite.scala | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/24cb323e/mllib/src/main/scala/org/apache/spark/ml/Model.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 7fd5153..70e7495 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -32,7 +32,7 @@ abstract class Model[M <: Model[M]] extends Transformer { * The parent estimator that produced this model. 
* Note: For ensembles' component Models, this value can be null. */ - var parent: Estimator[M] = _ + @transient var parent: Estimator[M] = _ /** * Sets the parent of this model (Java API). @@ -42,6 +42,9 @@ abstract class Model[M <: Model[M]] extends Transformer { this.asInstanceOf[M] } + /** Indicates whether this [[Model]] has a corresponding parent. */ + def hasParent: Boolean = parent != null + override def copy(extra: ParamMap): M = { // The default implementation of Params.copy doesn't work for models. throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)") http://git-wip-us.apache.org/repos/asf/spark/blob/24cb323e/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 4376524..97f9749 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -83,6 +83,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getRawPredictionCol === "rawPrediction") assert(model.getProbabilityCol === "probability") assert(model.intercept !== 0.0) + assert(model.hasParent) } test("logistic regression doesn't fit intercept when fitIntercept is off") { http://git-wip-us.apache.org/repos/asf/spark/blob/24cb323e/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 08f86fa..cdbbaca 
100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -162,5 +162,7 @@ private object RandomForestClassifierSuite { val oldModelAsNew = RandomForestClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.hasParent) + assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org