Repository: spark Updated Branches: refs/heads/master e557c53c5 -> e00cac989
[SPARK-25959][ML] GBTClassifier picks wrong impurity stats on loading ## What changes were proposed in this pull request? Our `GBTClassifier` supports only `variance` impurity. But unfortunately, its `impurity` param by default contains the value `gini`: it is not even modifiable by the user and it differs from the actual impurity used, which is `variance`. This issue is not limited to a wrong value being returned when the user queries `getImpurity`; it also affects the loading of a saved model, as its `impurityStats` are created as `gini` (since this is the value stored for the model impurity), which leads to wrong `featureImportances` in models loaded from saved ones. This PR changes the `impurity` param to one which allows only the value `variance`. ## How was this patch tested? modified UT Closes #22986 from mgaido91/SPARK-25959. Authored-by: Marco Gaido <marcogaid...@gmail.com> Signed-off-by: Sean Owen <sean.o...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e00cac98 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e00cac98 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e00cac98 Branch: refs/heads/master Commit: e00cac989821aea238c7bf20b69068ef7cf2eef3 Parents: e557c53 Author: Marco Gaido <marcogaid...@gmail.com> Authored: Sat Nov 17 09:46:45 2018 -0600 Committer: Sean Owen <sean.o...@databricks.com> Committed: Sat Nov 17 09:46:45 2018 -0600 ---------------------------------------------------------------------- .../spark/ml/classification/GBTClassifier.scala | 4 +++- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../org/apache/spark/ml/tree/treeParams.scala | 19 ++++++++++--------- .../ml/classification/GBTClassifierSuite.scala | 1 + project/MimaExcludes.scala | 11 +++++++++++ 6 files changed, 27 insertions(+), 12 deletions(-) 
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62cfa39..62c6bdb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -427,7 +427,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - metadata.getAndSetParams(model) + // We ignore the impurity while loading models because in previous models it was wrongly + // set to gini (see SPARK-25959). 
+ metadata.getAndSetParams(model, Some(List("impurity"))) model } } http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 6fa6562..c9de85d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S @Since("1.4.0") object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] { /** Accessor for supported impurities: variance */ - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities @Since("2.0.0") override def load(path: String): DecisionTreeRegressor = super.load(path) http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 82bf66f..66d57ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -146,7 +146,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{ /** Accessor for supported impurity settings: 
variance */ @Since("1.4.0") - final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ @Since("1.4.0") http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 00157fe..f1e3836 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams { private[ml] trait DecisionTreeClassifierParams extends DecisionTreeParams with TreeClassifierParams -/** - * Parameters for Decision Tree-based regression algorithms. - */ -private[ml] trait TreeRegressorParams extends Params { - +private[ml] trait HasVarianceImpurity extends Params { /** * Criterion used for information gain calculation (case-insensitive). * Supported: "variance". @@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params { */ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). 
Supported options:" + - s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}", (value: String) => - TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) + HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "variance") @@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params { } } -private[ml] object TreeRegressorParams { +private[ml] object HasVarianceImpurity { // These options should be lowercase. final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase(Locale.ROOT)) } +/** + * Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends HasVarianceImpurity + private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams with TreeRegressorParams with HasVarianceCol { @@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams { Array("logistic").map(_.toLowerCase(Locale.ROOT)) } -private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { +private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity { /** * Loss function which GBT tries to minimize. 
(case-insensitive) http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 3049776..cedbaf1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { model2: GBTClassificationModel): Unit = { TreeTests.checkEqual(model, model2) assert(model.numFeatures === model2.numFeatures) + assert(model.featureImportances == model2.featureImportances) } val gbt = new GBTClassifier() http://git-wip-us.apache.org/repos/asf/spark/blob/e00cac98/project/MimaExcludes.scala ---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b030b6c..a8d2b5d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,17 @@ object MimaExcludes { // Exclude rules for 3.0.x lazy val v30excludes = v24excludes ++ Seq( + // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + 
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), + // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org