Repository: spark
Updated Branches:
  refs/heads/master e557c53c5 -> e00cac989


[SPARK-25959][ML] GBTClassifier picks wrong impurity stats on loading

## What changes were proposed in this pull request?

Our `GBTClassifier` supports only the `variance` impurity. Unfortunately, its 
`impurity` param contains the value `gini` by default: it is not even 
modifiable by the user, and it differs from the impurity actually used, which 
is `variance`. The issue is not limited to a wrong value being returned when 
the user calls `getImpurity`; it also affects the loading of a saved model, 
because its `impurityStats` are created as `gini` (since this is the value 
stored for the model impurity), which leads to wrong `featureImportances` in 
models loaded from saved ones.

The PR changes the `impurity` param used to one which allows only the value 
`variance`.

## How was this patch tested?

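Modified an existing unit test in `GBTClassifierSuite`: its model-comparison 
check now also asserts that `featureImportances` are equal for the two models.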

Closes #22986 from mgaido91/SPARK-25959.

Authored-by: Marco Gaido <marcogaid...@gmail.com>
Signed-off-by: Sean Owen <sean.o...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e00cac98
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e00cac98
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e00cac98

Branch: refs/heads/master
Commit: e00cac989821aea238c7bf20b69068ef7cf2eef3
Parents: e557c53
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Sat Nov 17 09:46:45 2018 -0600
Committer: Sean Owen <sean.o...@databricks.com>
Committed: Sat Nov 17 09:46:45 2018 -0600

----------------------------------------------------------------------
 .../spark/ml/classification/GBTClassifier.scala  |  4 +++-
 .../ml/regression/DecisionTreeRegressor.scala    |  2 +-
 .../ml/regression/RandomForestRegressor.scala    |  2 +-
 .../org/apache/spark/ml/tree/treeParams.scala    | 19 ++++++++++---------
 .../ml/classification/GBTClassifierSuite.scala   |  1 +
 project/MimaExcludes.scala                       | 11 +++++++++++
 6 files changed, 27 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 62cfa39..62c6bdb 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -427,7 +427,9 @@ object GBTClassificationModel extends 
MLReadable[GBTClassificationModel] {
         s" trees based on metadata but found ${trees.length} trees.")
       val model = new GBTClassificationModel(metadata.uid,
         trees, treeWeights, numFeatures)
-      metadata.getAndSetParams(model)
+      // We ignore the impurity while loading models because in previous 
models it was wrongly
+      // set to gini (see SPARK-25959).
+      metadata.getAndSetParams(model, Some(List("impurity")))
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 6fa6562..c9de85d 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") 
(@Since("1.4.0") override val uid: S
 @Since("1.4.0")
 object DecisionTreeRegressor extends 
DefaultParamsReadable[DecisionTreeRegressor] {
   /** Accessor for supported impurities: variance */
-  final val supportedImpurities: Array[String] = 
TreeRegressorParams.supportedImpurities
+  final val supportedImpurities: Array[String] = 
HasVarianceImpurity.supportedImpurities
 
   @Since("2.0.0")
   override def load(path: String): DecisionTreeRegressor = super.load(path)

http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 82bf66f..66d57ad 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -146,7 +146,7 @@ class RandomForestRegressor @Since("1.4.0") 
(@Since("1.4.0") override val uid: S
 object RandomForestRegressor extends 
DefaultParamsReadable[RandomForestRegressor]{
   /** Accessor for supported impurity settings: variance */
   @Since("1.4.0")
-  final val supportedImpurities: Array[String] = 
TreeRegressorParams.supportedImpurities
+  final val supportedImpurities: Array[String] = 
HasVarianceImpurity.supportedImpurities
 
   /** Accessor for supported featureSubsetStrategy settings: auto, all, 
onethird, sqrt, log2 */
   @Since("1.4.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 00157fe..f1e3836 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams {
 private[ml] trait DecisionTreeClassifierParams
   extends DecisionTreeParams with TreeClassifierParams
 
-/**
- * Parameters for Decision Tree-based regression algorithms.
- */
-private[ml] trait TreeRegressorParams extends Params {
-
+private[ml] trait HasVarianceImpurity extends Params {
   /**
    * Criterion used for information gain calculation (case-insensitive).
    * Supported: "variance".
@@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params {
    */
   final val impurity: Param[String] = new Param[String](this, "impurity", 
"Criterion used for" +
     " information gain calculation (case-insensitive). Supported options:" +
-    s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
+    s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}",
     (value: String) =>
-      
TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
+      
HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
 
   setDefault(impurity -> "variance")
 
@@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params {
   }
 }
 
-private[ml] object TreeRegressorParams {
+private[ml] object HasVarianceImpurity {
   // These options should be lowercase.
   final val supportedImpurities: Array[String] =
     Array("variance").map(_.toLowerCase(Locale.ROOT))
 }
 
+/**
+ * Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends HasVarianceImpurity
+
 private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
   with TreeRegressorParams with HasVarianceCol {
 
@@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams {
     Array("logistic").map(_.toLowerCase(Locale.ROOT))
 }
 
-private[ml] trait GBTClassifierParams extends GBTParams with 
TreeClassifierParams {
+private[ml] trait GBTClassifierParams extends GBTParams with 
HasVarianceImpurity {
 
   /**
    * Loss function which GBT tries to minimize. (case-insensitive)

http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 3049776..cedbaf1 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with 
DefaultReadWriteTest {
         model2: GBTClassificationModel): Unit = {
       TreeTests.checkEqual(model, model2)
       assert(model.numFeatures === model2.numFeatures)
+      assert(model.featureImportances == model2.featureImportances)
     }
 
     val gbt = new GBTClassifier()

http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index b030b6c..a8d2b5d 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,17 @@ object MimaExcludes {
 
   // Exclude rules for 3.0.x
   lazy val v30excludes = v24excludes ++ Seq(
+    // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading
+    
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"),
+    
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+    
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+    
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+    
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+    
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+    
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"),
+    
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"),
+    
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"),
+
     // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"),
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to