Repository: spark
Updated Branches:
  refs/heads/master 7c1654e21 -> 252468a74
[SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes

## What changes were proposed in this pull request?

API:
```
trait ClassificationNode extends Node
  def getLabelCount(label: Int): Double

trait RegressionNode extends Node
  def getCount(): Double
  def getSum(): Double
  def getSumOfSquares(): Double

// turn LeafNode into a trait
trait LeafNode extends Node {
  def prediction: Double
  def impurity: Double
  ...
}

class ClassificationLeafNode extends ClassificationNode with LeafNode

class RegressionLeafNode extends RegressionNode with LeafNode

// turn InternalNode into a trait
trait InternalNode extends Node {
  def gain: Double
  def leftChild: Node
  def rightChild: Node
  def split: Split
  ...
}

class ClassificationInternalNode extends ClassificationNode with InternalNode
  override def leftChild: ClassificationNode
  override def rightChild: ClassificationNode

class RegressionInternalNode extends RegressionNode with InternalNode
  override val leftChild: RegressionNode
  override val rightChild: RegressionNode

class DecisionTreeClassificationModel
  override val rootNode: ClassificationNode

class DecisionTreeRegressionModel
  override val rootNode: RegressionNode
```

Closes #17466

## How was this patch tested?

Unit tests will be added soon.

Author: WeichenXu <weichen...@databricks.com>
Author: jkbradley <joseph.kurata.brad...@gmail.com>

Closes #20786 from WeichenXu123/tree_stat_api_2.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/252468a7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/252468a7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/252468a7

Branch: refs/heads/master
Commit: 252468a744b95082400ba9e8b2e3b3d9d50ab7fa
Parents: 7c1654e
Author: WeichenXu <weichen...@databricks.com>
Authored: Mon Apr 9 12:18:07 2018 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Apr 9 12:18:07 2018 -0700

----------------------------------------------------------------------
 .../classification/DecisionTreeClassifier.scala |  14 +-
 .../spark/ml/classification/GBTClassifier.scala |   6 +-
 .../classification/RandomForestClassifier.scala |   6 +-
 .../ml/regression/DecisionTreeRegressor.scala   |  13 +-
 .../spark/ml/regression/GBTRegressor.scala      |   6 +-
 .../ml/regression/RandomForestRegressor.scala   |   6 +-
 .../scala/org/apache/spark/ml/tree/Node.scala   | 247 +++++++++++++++----
 .../spark/ml/tree/impl/RandomForest.scala       |  10 +-
 .../org/apache/spark/ml/tree/treeModels.scala   |  36 ++-
 .../DecisionTreeClassifierSuite.scala           |  31 ++-
 .../ml/classification/GBTClassifierSuite.scala  |   4 +-
 .../RandomForestClassifierSuite.scala           |   5 +-
 .../regression/DecisionTreeRegressorSuite.scala |  14 ++
 .../spark/ml/tree/impl/RandomForestSuite.scala  |  22 +-
 .../apache/spark/ml/tree/impl/TreeTests.scala   |  12 +-
 project/MimaExcludes.scala                      |   9 +-
 16 files changed, 333 insertions(+), 108 deletions(-)
----------------------------------------------------------------------
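Before the per-file diff, a quick usage sketch of the API this patch exposes. This is illustrative only, not part of the patch: `df` is assumed to be a prepared training DataFrame with the usual "label"/"features" columns, and the variable names are mine.

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Fit a small tree; `df` is an assumed training DataFrame.
val model = new DecisionTreeClassifier().setMaxDepth(3).fit(df)

// rootNode is now typed as ClassificationNode, so the new accessor is
// available without any cast.
val root = model.rootNode
val counts = Array.tabulate(model.numClasses)(root.getLabelCount)
// Empirical class distribution of the training data reaching the root.
val probs = counts.map(_ / counts.sum)
```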
http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 65cce69..771cd4f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -165,7 +165,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi
 @Since("1.4.0")
 class DecisionTreeClassificationModel private[ml] (
     @Since("1.4.0")override val uid: String,
-    @Since("1.4.0")override val rootNode: Node,
+    @Since("1.4.0")override val rootNode: ClassificationNode,
     @Since("1.6.0")override val numFeatures: Int,
     @Since("1.5.0")override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
@@ -178,7 +178,7 @@ class DecisionTreeClassificationModel private[ml] (
    * Construct a decision tree classification model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
+  private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) =
     this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
 
   override def predict(features: Vector): Double = {
@@ -276,8 +276,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
-      val root = loadTreeNodes(path, metadata, sparkSession)
-      val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
+      val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true)
+      val model = new DecisionTreeClassificationModel(metadata.uid,
+        root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
@@ -292,9 +293,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
     require(oldModel.algo == OldAlgo.Classification,
       s"Cannot convert non-classification DecisionTreeModel (old API) to" +
         s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
-    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
     // Can't infer number of features from old model, so default to -1
-    new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
+    new DecisionTreeClassificationModel(uid,
+      rootNode.asInstanceOf[ClassificationNode], numFeatures, -1)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index cd44489..c025510 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -371,14 +371,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
     override def load(path: String): GBTClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
       val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
       val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree =
-            new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+          val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
+            root.asInstanceOf[RegressionNode], numFeatures)
           DefaultParamsReader.getAndSetParams(tree, treeMetadata)
           tree
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 78a4972..bb972e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -310,15 +310,15 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
     override def load(path: String): RandomForestClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
 
       val trees: Array[DecisionTreeClassificationModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree =
-            new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
+          val tree = new DecisionTreeClassificationModel(treeMetadata.uid,
+            root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
           DefaultParamsReader.getAndSetParams(tree, treeMetadata)
           tree
       }
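The same typing carries through the ensembles: each tree in a fitted or loaded forest now exposes a `ClassificationNode` root. A hedged sketch, again assuming a prepared training DataFrame `df` with two classes:

```scala
import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier().setNumTrees(10).fit(df)
rf.trees.foreach { tree =>
  // Training example counts per class at each tree's root node.
  val n0 = tree.rootNode.getLabelCount(0)
  val n1 = tree.rootNode.getLabelCount(1)
  println(s"${tree.uid}: class 0 -> $n0, class 1 -> $n1")
}
```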
http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index ad154fc..5cef5c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
 @Since("1.4.0")
 class DecisionTreeRegressionModel private[ml] (
     override val uid: String,
-    override val rootNode: Node,
+    override val rootNode: RegressionNode,
     override val numFeatures: Int)
   extends PredictionModel[Vector, DecisionTreeRegressionModel]
   with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {
@@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] (
    * Construct a decision tree regression model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: Node, numFeatures: Int) =
+  private[ml] def this(rootNode: RegressionNode, numFeatures: Int) =
     this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
 
   override def predict(features: Vector): Double = {
@@ -279,8 +279,9 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
       implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
-      val root = loadTreeNodes(path, metadata, sparkSession)
-      val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures)
+      val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false)
+      val model = new DecisionTreeRegressionModel(metadata.uid,
+        root.asInstanceOf[RegressionNode], numFeatures)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
@@ -295,8 +296,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
     require(oldModel.algo == OldAlgo.Regression,
       s"Cannot convert non-regression DecisionTreeModel (old API) to" +
         s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
-    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
-    new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
+    new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 6569ff2..834aaa0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -302,15 +302,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
     override def load(path: String): GBTRegressionModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree =
-            new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+          val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
+            root.asInstanceOf[RegressionNode], numFeatures)
           DefaultParamsReader.getAndSetParams(tree, treeMetadata)
           tree
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 2d59446..7f77398 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -269,13 +269,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
     override def load(path: String): RandomForestRegressionModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree =
-            new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+          val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
+            root.asInstanceOf[RegressionNode], numFeatures)
           DefaultParamsReader.getAndSetParams(tree, treeMetadata)
           tree
       }
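On the regression side, the three sufficient statistics the patch exposes (count, sum, sum of squares) are enough to recover the label mean and variance at any node. A minimal sketch, assuming `model` is a fitted DecisionTreeRegressionModel; the Node.scala changes that provide these accessors follow below.

```scala
// Derive node-level label statistics from the new accessors.
val node = model.rootNode  // typed as RegressionNode after this patch
val n = node.getCount
val mean = node.getSum / n
// Population variance of the labels reaching this node: E[y^2] - (E[y])^2.
val variance = node.getSumOfSquares / n - mean * mean
println(s"n = $n, mean = $mean, variance = $variance")
```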
http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index d30be45..0242bc76 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -17,14 +17,16 @@
 
 package org.apache.spark.ml.tree
 
+import org.apache.spark.annotation.Since
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
+import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats,
+  Node => OldNode, Predict => OldPredict}
 
 /**
  * Decision tree node interface.
  */
-sealed abstract class Node extends Serializable {
+sealed trait Node extends Serializable {
 
   // TODO: Add aggregate stats (once available).  This will happen after we move the DecisionTree
   // code into the new API and deprecate the old API.  SPARK-3727
@@ -84,35 +86,86 @@ private[ml] object Node {
   /**
    * Create a new Node from the old Node format, recursively creating child nodes as needed.
    */
-  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+  def fromOld(
+      oldNode: OldNode,
+      categoricalFeatures: Map[Int, Int],
+      isClassification: Boolean): Node = {
     if (oldNode.isLeaf) {
       // TODO: Once the implementation has been moved to this API, then include sufficient
       //       statistics here.
-      new LeafNode(prediction = oldNode.predict.predict,
-        impurity = oldNode.impurity, impurityStats = null)
+      if (isClassification) {
+        new ClassificationLeafNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, impurityStats = null)
+      } else {
+        new RegressionLeafNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, impurityStats = null)
+      }
     } else {
       val gain = if (oldNode.stats.nonEmpty) {
         oldNode.stats.get.gain
       } else {
         0.0
       }
-      new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
-        gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
-        rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
-        split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
+      if (isClassification) {
+        new ClassificationInternalNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, gain = gain,
+          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true)
+            .asInstanceOf[ClassificationNode],
+          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true)
+            .asInstanceOf[ClassificationNode],
+          split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
+      } else {
+        new RegressionInternalNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, gain = gain,
+          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false)
+            .asInstanceOf[RegressionNode],
+          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false)
+            .asInstanceOf[RegressionNode],
+          split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
+      }
     }
   }
 }
 
-/**
- * Decision tree leaf node.
- * @param prediction  Prediction this node makes
- * @param impurity  Impurity measure at this node (for training data)
- */
-class LeafNode private[ml] (
-    override val prediction: Double,
-    override val impurity: Double,
-    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
+@Since("2.4.0")
+sealed trait ClassificationNode extends Node {
+
+  /**
+   * Get count of training examples for specified label in this node
+   * @param label label number in the range [0, numClasses)
+   */
+  @Since("2.4.0")
+  def getLabelCount(label: Int): Double = {
+    require(label >= 0 && label < impurityStats.stats.length,
+      "label should be in the range between 0 (inclusive) " +
+      s"and ${impurityStats.stats.length} (exclusive).")
+    impurityStats.stats(label)
+  }
+}
+
+@Since("2.4.0")
+sealed trait RegressionNode extends Node {
+
+  /** Number of training data points in this node */
+  @Since("2.4.0")
+  def getCount: Double = impurityStats.stats(0)
+
+  /** Sum over training data points of the labels in this node */
+  @Since("2.4.0")
+  def getSum: Double = impurityStats.stats(1)
+
+  /** Sum over training data points of the square of the labels in this node */
+  @Since("2.4.0")
+  def getSumOfSquares: Double = impurityStats.stats(2)
+}
+
+@Since("2.4.0")
+sealed trait LeafNode extends Node {
+
+  /** Prediction this node makes. */
+  def prediction: Double
+
+  def impurity: Double
 
   override def toString: String =
     s"LeafNode(prediction = $prediction, impurity = $impurity)"
@@ -135,32 +188,58 @@ class LeafNode private[ml] (
 
   override private[ml] def maxSplitFeatureIndex(): Int = -1
+}
+
+/**
+ * Decision tree leaf node for classification.
+ */
+@Since("2.4.0")
+class ClassificationLeafNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double,
+    override private[ml] val impurityStats: ImpurityCalculator)
+  extends ClassificationNode with LeafNode {
+
   override private[tree] def deepCopy(): Node = {
-    new LeafNode(prediction, impurity, impurityStats)
+    new ClassificationLeafNode(prediction, impurity, impurityStats)
   }
 }
 
 /**
- * Internal Decision Tree node.
- * @param prediction  Prediction this node would make if it were a leaf node
- * @param impurity  Impurity measure at this node (for training data)
- * @param gain  Information gain value. Values less than 0 indicate missing values;
- *              this quirk will be removed with future updates.
- * @param leftChild  Left-hand child node
- * @param rightChild  Right-hand child node
- * @param split  Information about the test used to split to the left or right child.
+ * Decision tree leaf node for regression.
  */
-class InternalNode private[ml] (
+@Since("2.4.0")
+class RegressionLeafNode private[ml] (
     override val prediction: Double,
     override val impurity: Double,
-    val gain: Double,
-    val leftChild: Node,
-    val rightChild: Node,
-    val split: Split,
-    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
+    override private[ml] val impurityStats: ImpurityCalculator)
+  extends RegressionNode with LeafNode {
 
-  // Note to developers: The constructor argument impurityStats should be reconsidered before we
-  // make the constructor public.  We may be able to improve the representation.
+  override private[tree] def deepCopy(): Node = {
+    new RegressionLeafNode(prediction, impurity, impurityStats)
+  }
+}
+
+/**
+ * Internal Decision Tree node.
+ */
+@Since("2.4.0")
+sealed trait InternalNode extends Node {
+
+  /**
+   * Information gain value. Values less than 0 indicate missing values;
+   * this quirk will be removed with future updates.
+   */
+  def gain: Double
+
+  /** Left-hand child node */
+  def leftChild: Node
+
+  /** Right-hand child node */
+  def rightChild: Node
+
+  /** Information about the test used to split to the left or right child. */
+  def split: Split
 
   override def toString: String = {
     s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
@@ -205,11 +284,6 @@ class InternalNode private[ml] (
       math.max(split.featureIndex,
         math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
   }
-
-  override private[tree] def deepCopy(): Node = {
-    new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(),
-      split, impurityStats)
-  }
 }
 
 private object InternalNode {
@@ -241,6 +315,57 @@ private object InternalNode {
 }
 
 /**
+ * Internal Decision Tree node for classification.
+ */
+@Since("2.4.0")
+class ClassificationInternalNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double,
+    override val gain: Double,
+    override val leftChild: ClassificationNode,
+    override val rightChild: ClassificationNode,
+    override val split: Split,
+    override private[ml] val impurityStats: ImpurityCalculator)
+  extends ClassificationNode with InternalNode {
+
+  // Note to developers: The constructor argument impurityStats should be reconsidered before we
+  // make the constructor public.  We may be able to improve the representation.
+
+  override private[tree] def deepCopy(): Node = {
+    new ClassificationInternalNode(prediction, impurity, gain,
+      leftChild.deepCopy().asInstanceOf[ClassificationNode],
+      rightChild.deepCopy().asInstanceOf[ClassificationNode],
+      split, impurityStats)
+  }
+}
+
+/**
+ * Internal Decision Tree node for regression.
+ */
+@Since("2.4.0")
+class RegressionInternalNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double,
+    override val gain: Double,
+    override val leftChild: RegressionNode,
+    override val rightChild: RegressionNode,
+    override val split: Split,
+    override private[ml] val impurityStats: ImpurityCalculator)
+  extends RegressionNode with InternalNode {
+
+  // Note to developers: The constructor argument impurityStats should be reconsidered before we
+  // make the constructor public.  We may be able to improve the representation.
+
+  override private[tree] def deepCopy(): Node = {
+    new RegressionInternalNode(prediction, impurity, gain,
+      leftChild.deepCopy().asInstanceOf[RegressionNode],
+      rightChild.deepCopy().asInstanceOf[RegressionNode],
+      split, impurityStats)
+  }
+}
+
+
+/**
  * Version of a node used in learning.  This uses vars so that we can modify nodes as we split the
  * tree by adding children, etc.
  *
@@ -265,30 +390,52 @@ private[tree] class LearningNode(
     var isLeaf: Boolean,
     var stats: ImpurityStats) extends Serializable {
 
-  def toNode: Node = toNode(prune = true)
+  def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true)
+
+  def toClassificationNode(prune: Boolean = true): ClassificationNode = {
+    toNode(true, prune).asInstanceOf[ClassificationNode]
+  }
+
+  def toRegressionNode(prune: Boolean = true): RegressionNode = {
+    toNode(false, prune).asInstanceOf[RegressionNode]
+  }
 
   /**
    * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
    */
-  def toNode(prune: Boolean = true): Node = {
+  def toNode(isClassification: Boolean, prune: Boolean): Node = {
 
     if (!leftChild.isEmpty || !rightChild.isEmpty) {
       assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
         "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
-      (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
+      (leftChild.get.toNode(isClassification, prune),
+        rightChild.get.toNode(isClassification, prune)) match {
         case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
-          new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+          if (isClassification) {
+            new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+          } else {
+            new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+          }
         case (l, r) =>
-          new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
-            l, r, split.get, stats.impurityCalculator)
+          if (isClassification) {
+            new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity,
+              stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode],
+              split.get, stats.impurityCalculator)
+          } else {
+            new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+              l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode],
+              split.get, stats.impurityCalculator)
+          }
       }
     } else {
-      if (stats.valid) {
-        new LeafNode(stats.impurityCalculator.predict, stats.impurity,
+      // Here we want to keep same behavior with the old mllib.DecisionTreeModel
+      val impurity = if (stats.valid) stats.impurity else -1.0
+      if (isClassification) {
+        new ClassificationLeafNode(stats.impurityCalculator.predict, impurity,
          stats.impurityCalculator)
       } else {
-        // Here we want to keep same behavior with the old mllib.DecisionTreeModel
-        new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
+        new RegressionLeafNode(stats.impurityCalculator.predict, impurity,
+          stats.impurityCalculator)
       }
     }
   }
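Since the refactored hierarchy in Node.scala above is sealed, user code can traverse a tree by matching on the two concrete classification node types. A sketch only; the helper name `leafLabelCounts` is mine, not part of the patch:

```scala
import org.apache.spark.ml.tree.{ClassificationInternalNode, ClassificationLeafNode, ClassificationNode}

// Collect the per-leaf count of training examples with the given label,
// recursing left-to-right; the sealed trait makes this match exhaustive.
def leafLabelCounts(node: ClassificationNode, label: Int): Seq[Double] = node match {
  case leaf: ClassificationLeafNode => Seq(leaf.getLabelCount(label))
  case internal: ClassificationInternalNode =>
    leafLabelCounts(internal.leftChild, label) ++ leafLabelCounts(internal.rightChild, label)
}
```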
http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 16f32d7..056a94b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -224,23 +224,23 @@ private[spark] object RandomForest extends Logging {
       case Some(uid) =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures,
-              strategy.getNumClasses)
+            new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune),
+              numFeatures, strategy.getNumClasses)
           }
         } else {
           topNodes.map { rootNode =>
-            new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures)
+            new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures)
           }
         }
       case None =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures,
+            new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures,
               strategy.getNumClasses)
           }
         } else {
           topNodes.map(rootNode =>
-            new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures))
+            new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures))
         }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 4aa4c36..f027b14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -219,8 +219,10 @@ private[ml] object TreeEnsembleModel {
         importances.changeValue(feature, scaledGain, _ + scaledGain)
         computeFeatureImportance(n.leftChild, importances)
         computeFeatureImportance(n.rightChild, importances)
-      case n: LeafNode =>
+      case _: LeafNode =>
       // do nothing
+      case _ =>
+        throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}")
     }
   }
 
@@ -317,6 +319,8 @@ private[ml] object DecisionTreeModelReadWrite {
         (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
           -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), id)
+      case _ =>
+        throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}")
     }
   }
 
@@ -327,7 +331,7 @@ private[ml] object DecisionTreeModelReadWrite {
   def loadTreeNodes(
       path: String,
       metadata: DefaultParamsReader.Metadata,
-      sparkSession: SparkSession): Node = {
+      sparkSession: SparkSession, isClassification: Boolean): Node = {
     import sparkSession.implicits._
     implicit val format = DefaultFormats
 
@@ -339,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite {
 
     val dataPath = new Path(path, "data").toString
     val data = sparkSession.read.parquet(dataPath).as[NodeData]
-    buildTreeFromNodes(data.collect(), impurityType)
+    buildTreeFromNodes(data.collect(), impurityType, isClassification)
   }
 
   /**
@@ -348,7 +352,8 @@ private[ml] object DecisionTreeModelReadWrite {
    * @param impurityType  Impurity type for this tree
    * @return Root node of reconstructed tree
    */
-  def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
+  def buildTreeFromNodes(data: Array[NodeData], impurityType: String,
+    isClassification: Boolean): Node = {
     // Load all nodes, sorted by ID.
     val nodes = data.sortBy(_.id)
     // Sanity checks; could remove
@@ -364,10 +369,21 @@ private[ml] object DecisionTreeModelReadWrite {
       val node = if (n.leftChild != -1) {
         val leftChild = finalNodes(n.leftChild)
         val rightChild = finalNodes(n.rightChild)
-        new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild,
-          n.split.getSplit, impurityStats)
+        if (isClassification) {
+          new ClassificationInternalNode(n.prediction, n.impurity, n.gain,
+            leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode],
+            n.split.getSplit, impurityStats)
+        } else {
+          new RegressionInternalNode(n.prediction, n.impurity, n.gain,
+            leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode],
+            n.split.getSplit, impurityStats)
+        }
       } else {
-        new LeafNode(n.prediction, n.impurity, impurityStats)
+        if (isClassification) {
+          new ClassificationLeafNode(n.prediction, n.impurity, impurityStats)
+        } else {
+          new RegressionLeafNode(n.prediction, n.impurity, impurityStats)
+        }
       }
       finalNodes(n.id) = node
     }
@@ -421,7 +437,8 @@ private[ml] object EnsembleModelReadWrite {
       path: String,
       sql: SparkSession,
       className: String,
-      treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
+      treeClassName: String,
+      isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
     import sql.implicits._
     implicit val format = DefaultFormats
     val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
@@ -449,7 +466,8 @@ private[ml] object EnsembleModelReadWrite {
     val rootNodesRDD: RDD[(Int, Node)] =
       nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
         case (treeID: Int, nodeData: Iterable[NodeData]) =>
-          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(
+            nodeData.toArray, impurityType, isClassification)
       }
     val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
     (metadata, treesMetadata.zip(rootNodes), treesWeights)
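The read/write changes above thread `isClassification` through persistence, so the stats survive a save/load round trip. A hedged sketch; `model` is assumed to be a fitted DecisionTreeClassificationModel and the path is illustrative:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassificationModel

model.write.overwrite().save("/tmp/dtc-stats-demo")
val loaded = DecisionTreeClassificationModel.load("/tmp/dtc-stats-demo")
// Label counts are reconstructed from the persisted impurityStats.
assert(loaded.rootNode.getLabelCount(0) == model.rootNode.getLabelCount(0))
```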
http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 2930f49..d3dbb4e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.ClassificationLeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -61,7 +61,8 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
 
   test("params") {
     ParamsSuite.checkParams(new DecisionTreeClassifier)
-    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)
+    val model = new DecisionTreeClassificationModel("dtc",
+      new ClassificationLeafNode(0.0, 0.0, null), 1, 2)
     ParamsSuite.checkParams(model)
   }
 
@@ -375,6 +376,32 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
     testDefaultReadWrite(model)
   }
+
+  test("label/impurity stats") {
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+      LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+      LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+    val rdd = sc.parallelize(arr)
+    val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2)
+    val dt1 = new DecisionTreeClassifier()
+      .setImpurity("entropy")
+      .setMaxDepth(2)
+      .setMinInstancesPerNode(2)
+    val model1 = dt1.fit(df)
+
+    val rootNode1 = model1.rootNode
+    assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0))
+
+    val dt2 = new DecisionTreeClassifier()
+      .setImpurity("gini")
+      .setMaxDepth(2)
+      .setMinInstancesPerNode(2)
+    val model2 = dt2.fit(df)
+
+    val rootNode2 = model2.rootNode
+    assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0))
+  }
 }
 
 private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 5779606..f0ee549 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.RegressionLeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
@@ -69,7 +69,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new GBTClassifier)
     val model = new GBTClassificationModel("gbtc",
-      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
+      Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)),
       Array(1.0), 1, 2)
     ParamsSuite.checkParams(model)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ba4a9cf..3062aa9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.ClassificationLeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -71,7 +71,8 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new RandomForestClassifier)
     val model = new RandomForestClassificationModel("rfc",
-      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
+      Array(new DecisionTreeClassificationModel("dtc",
+        new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
     ParamsSuite.checkParams(model)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 29a4383..9ae2733 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -191,6 +191,20 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
+
+  test("label/impurity stats") {
+    val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+    val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
+    val dtr = new DecisionTreeRegressor()
+      .setImpurity("variance")
+      .setMaxDepth(2)
+      .setMaxBins(8)
+    val model = dtr.fit(df)
+    val statInfo = model.rootNode
+
+    assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0
+      && statInfo.getSumOfSquares == 600.0)
+  }
 }
 
 private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 743dacf..4dbbd75 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.stats.impurity > 0.0)
 
     // set impurity and predict for child nodes
-    assert(topNode.leftChild.get.toNode.prediction === 0.0)
-    assert(topNode.rightChild.get.toNode.prediction === 1.0)
+    assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0)
+    assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0)
     assert(topNode.leftChild.get.stats.impurity === 0.0)
     assert(topNode.rightChild.get.stats.impurity === 0.0)
   }
@@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(topNode.stats.impurity > 0.0)
 
     // set impurity and predict for child nodes
-    assert(topNode.leftChild.get.toNode.prediction === 0.0)
-    assert(topNode.rightChild.get.toNode.prediction === 1.0)
+    assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0)
+    assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0)
     assert(topNode.leftChild.get.stats.impurity === 0.0)
     assert(topNode.rightChild.get.stats.impurity === 0.0)
   }
@@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
           left      right
      */
     val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
-    val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
+    val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp)
 
     val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
-    val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
+    val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp)
 
-    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
+    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true)
     val parentImp = parent.impurityStats
 
     val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
-    val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
+    val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp)
 
-    val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
+    val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true)
     val grandImp = grandParent.impurityStats
 
     // Test feature importance computed at different subtrees.
@@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     // Forest consisting of (full tree) + (internal node with 2 leafs)
     val trees = Array(parent, grandParent).map { root =>
-      new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
-        .asInstanceOf[DecisionTreeModel]
+      new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode],
+        numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel]
     }
     val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
     val tree2norm = feature0importance + feature1importance

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index b6894b3..3f03d90 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite {
    * @param split  Split for parent node
    * @return  Parent node with children attached
    */
-  def buildParentNode(left: Node, right: Node, split: Split): Node = {
+  def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = {
     val leftImp = left.impurityStats
     val rightImp = right.impurityStats
     val parentImp = leftImp.copy.add(rightImp)
@@ -168,7 +168,15 @@ private[ml] object TreeTests extends SparkFunSuite {
     val gain = parentImp.calculate() -
       (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
     val pred = parentImp.predict
-    new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
+    if (isClassification) {
+      new ClassificationInternalNode(pred, parentImp.calculate(), gain,
+        left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode],
+        split, parentImp)
+    } else {
+      new RegressionInternalNode(pred, parentImp.calculate(), gain,
+        left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode],
+        split, parentImp)
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 1b6d1de..b37b4d5 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -55,7 +55,14 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"),
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel")
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"),
+
+    // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes
+    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"),
+    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"),
+    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this")
   )
 
   // Exclude rules for 2.3.x

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org