Repository: spark
Updated Branches:
  refs/heads/master 7c1654e21 -> 252468a74


[SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes

## What changes were proposed in this pull request?

API:
```
trait ClassificationNode extends Node
  def getLabelCount(label: Int): Double

trait RegressionNode extends Node
  def getCount: Double
  def getSum: Double
  def getSumOfSquares: Double

// LeafNode is now a trait (previously a class)
trait LeafNode extends Node {
  def prediction: Double
  def impurity: Double
  ...
}

class ClassificationLeafNode extends ClassificationNode with LeafNode

class RegressionLeafNode extends RegressionNode with LeafNode

// InternalNode is now a trait (previously a class)
trait InternalNode extends Node {
  def gain: Double
  def leftChild: Node
  def rightChild: Node
  def split: Split
  ...
}

class ClassificationInternalNode extends ClassificationNode with InternalNode
  override val leftChild: ClassificationNode
  override val rightChild: ClassificationNode

class RegressionInternalNode extends RegressionNode with InternalNode
  override val leftChild: RegressionNode
  override val rightChild: RegressionNode

class DecisionTreeClassificationModel
  override val rootNode: ClassificationNode

class DecisionTreeRegressionModel
  override val rootNode: RegressionNode
```
Closes #17466

## How was this patch tested?

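Unit tests added in DecisionTreeClassifierSuite and DecisionTreeRegressorSuite; existing tree test suites (GBT, RandomForest, RandomForestSuite, TreeTests) updated for the new node types.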

Author: WeichenXu <weichen...@databricks.com>
Author: jkbradley <joseph.kurata.brad...@gmail.com>

Closes #20786 from WeichenXu123/tree_stat_api_2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/252468a7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/252468a7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/252468a7

Branch: refs/heads/master
Commit: 252468a744b95082400ba9e8b2e3b3d9d50ab7fa
Parents: 7c1654e
Author: WeichenXu <weichen...@databricks.com>
Authored: Mon Apr 9 12:18:07 2018 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Apr 9 12:18:07 2018 -0700

----------------------------------------------------------------------
 .../classification/DecisionTreeClassifier.scala |  14 +-
 .../spark/ml/classification/GBTClassifier.scala |   6 +-
 .../classification/RandomForestClassifier.scala |   6 +-
 .../ml/regression/DecisionTreeRegressor.scala   |  13 +-
 .../spark/ml/regression/GBTRegressor.scala      |   6 +-
 .../ml/regression/RandomForestRegressor.scala   |   6 +-
 .../scala/org/apache/spark/ml/tree/Node.scala   | 247 +++++++++++++++----
 .../spark/ml/tree/impl/RandomForest.scala       |  10 +-
 .../org/apache/spark/ml/tree/treeModels.scala   |  36 ++-
 .../DecisionTreeClassifierSuite.scala           |  31 ++-
 .../ml/classification/GBTClassifierSuite.scala  |   4 +-
 .../RandomForestClassifierSuite.scala           |   5 +-
 .../regression/DecisionTreeRegressorSuite.scala |  14 ++
 .../spark/ml/tree/impl/RandomForestSuite.scala  |  22 +-
 .../apache/spark/ml/tree/impl/TreeTests.scala   |  12 +-
 project/MimaExcludes.scala                      |   9 +-
 16 files changed, 333 insertions(+), 108 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 65cce69..771cd4f 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -165,7 +165,7 @@ object DecisionTreeClassifier extends 
DefaultParamsReadable[DecisionTreeClassifi
 @Since("1.4.0")
 class DecisionTreeClassificationModel private[ml] (
     @Since("1.4.0")override val uid: String,
-    @Since("1.4.0")override val rootNode: Node,
+    @Since("1.4.0")override val rootNode: ClassificationNode,
     @Since("1.6.0")override val numFeatures: Int,
     @Since("1.5.0")override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, 
DecisionTreeClassificationModel]
@@ -178,7 +178,7 @@ class DecisionTreeClassificationModel private[ml] (
    * Construct a decision tree classification model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
+  private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, 
numClasses: Int) =
     this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
 
   override def predict(features: Vector): Double = {
@@ -276,8 +276,9 @@ object DecisionTreeClassificationModel extends 
MLReadable[DecisionTreeClassifica
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
-      val root = loadTreeNodes(path, metadata, sparkSession)
-      val model = new DecisionTreeClassificationModel(metadata.uid, root, 
numFeatures, numClasses)
+      val root = loadTreeNodes(path, metadata, sparkSession, isClassification 
= true)
+      val model = new DecisionTreeClassificationModel(metadata.uid,
+        root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
@@ -292,9 +293,10 @@ object DecisionTreeClassificationModel extends 
MLReadable[DecisionTreeClassifica
     require(oldModel.algo == OldAlgo.Classification,
       s"Cannot convert non-classification DecisionTreeModel (old API) to" +
         s" DecisionTreeClassificationModel (new API).  Algo is: 
${oldModel.algo}")
-    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, 
isClassification = true)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
     // Can't infer number of features from old model, so default to -1
-    new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
+    new DecisionTreeClassificationModel(uid,
+      rootNode.asInstanceOf[ClassificationNode], numFeatures, -1)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index cd44489..c025510 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -371,14 +371,14 @@ object GBTClassificationModel extends 
MLReadable[GBTClassificationModel] {
     override def load(path: String): GBTClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], 
treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName, false)
       val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
       val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree =
-            new DecisionTreeRegressionModel(treeMetadata.uid, root, 
numFeatures)
+          val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
+            root.asInstanceOf[RegressionNode], numFeatures)
           DefaultParamsReader.getAndSetParams(tree, treeMetadata)
           tree
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 78a4972..bb972e9 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -310,15 +310,15 @@ object RandomForestClassificationModel extends 
MLReadable[RandomForestClassifica
     override def load(path: String): RandomForestClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName, true)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
 
       val trees: Array[DecisionTreeClassificationModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree =
-            new DecisionTreeClassificationModel(treeMetadata.uid, root, 
numFeatures, numClasses)
+          val tree = new DecisionTreeClassificationModel(treeMetadata.uid,
+            root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
           DefaultParamsReader.getAndSetParams(tree, treeMetadata)
           tree
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index ad154fc..5cef5c9 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -160,7 +160,7 @@ object DecisionTreeRegressor extends 
DefaultParamsReadable[DecisionTreeRegressor
 @Since("1.4.0")
 class DecisionTreeRegressionModel private[ml] (
     override val uid: String,
-    override val rootNode: Node,
+    override val rootNode: RegressionNode,
     override val numFeatures: Int)
   extends PredictionModel[Vector, DecisionTreeRegressionModel]
   with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with 
Serializable {
@@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] (
    * Construct a decision tree regression model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: Node, numFeatures: Int) =
+  private[ml] def this(rootNode: RegressionNode, numFeatures: Int) =
     this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
 
   override def predict(features: Vector): Double = {
@@ -279,8 +279,9 @@ object DecisionTreeRegressionModel extends 
MLReadable[DecisionTreeRegressionMode
       implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
-      val root = loadTreeNodes(path, metadata, sparkSession)
-      val model = new DecisionTreeRegressionModel(metadata.uid, root, 
numFeatures)
+      val root = loadTreeNodes(path, metadata, sparkSession, isClassification 
= false)
+      val model = new DecisionTreeRegressionModel(metadata.uid,
+        root.asInstanceOf[RegressionNode], numFeatures)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
@@ -295,8 +296,8 @@ object DecisionTreeRegressionModel extends 
MLReadable[DecisionTreeRegressionMode
     require(oldModel.algo == OldAlgo.Regression,
       s"Cannot convert non-regression DecisionTreeModel (old API) to" +
         s" DecisionTreeRegressionModel (new API).  Algo is: ${oldModel.algo}")
-    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, 
isClassification = false)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
-    new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
+    new DecisionTreeRegressionModel(uid, 
rootNode.asInstanceOf[RegressionNode], numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 6569ff2..834aaa0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -302,15 +302,15 @@ object GBTRegressionModel extends 
MLReadable[GBTRegressionModel] {
     override def load(path: String): GBTRegressionModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], 
treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName, false)
 
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree =
-            new DecisionTreeRegressionModel(treeMetadata.uid, root, 
numFeatures)
+          val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
+            root.asInstanceOf[RegressionNode], numFeatures)
           DefaultParamsReader.getAndSetParams(tree, treeMetadata)
           tree
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 2d59446..7f77398 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -269,13 +269,13 @@ object RandomForestRegressionModel extends 
MLReadable[RandomForestRegressionMode
     override def load(path: String): RandomForestRegressionModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], 
treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName, false)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map { case 
(treeMetadata, root) =>
-        val tree =
-          new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+        val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
+          root.asInstanceOf[RegressionNode], numFeatures)
         DefaultParamsReader.getAndSetParams(tree, treeMetadata)
         tree
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index d30be45..0242bc76 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -17,14 +17,16 @@
 
 package org.apache.spark.ml.tree
 
+import org.apache.spark.annotation.Since
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats 
=> OldInformationGainStats, Node => OldNode, Predict => OldPredict}
+import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats 
=> OldInformationGainStats,
+  Node => OldNode, Predict => OldPredict}
 
 /**
  * Decision tree node interface.
  */
-sealed abstract class Node extends Serializable {
+sealed trait Node extends Serializable {
 
   // TODO: Add aggregate stats (once available).  This will happen after we 
move the DecisionTree
   //       code into the new API and deprecate the old API.  SPARK-3727
@@ -84,35 +86,86 @@ private[ml] object Node {
   /**
    * Create a new Node from the old Node format, recursively creating child 
nodes as needed.
    */
-  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+  def fromOld(
+      oldNode: OldNode,
+      categoricalFeatures: Map[Int, Int],
+      isClassification: Boolean): Node = {
     if (oldNode.isLeaf) {
       // TODO: Once the implementation has been moved to this API, then 
include sufficient
       //       statistics here.
-      new LeafNode(prediction = oldNode.predict.predict,
-        impurity = oldNode.impurity, impurityStats = null)
+      if (isClassification) {
+        new ClassificationLeafNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, impurityStats = null)
+      } else {
+        new RegressionLeafNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, impurityStats = null)
+      }
     } else {
       val gain = if (oldNode.stats.nonEmpty) {
         oldNode.stats.get.gain
       } else {
         0.0
       }
-      new InternalNode(prediction = oldNode.predict.predict, impurity = 
oldNode.impurity,
-        gain = gain, leftChild = fromOld(oldNode.leftNode.get, 
categoricalFeatures),
-        rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
-        split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
+      if (isClassification) {
+        new ClassificationInternalNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, gain = gain,
+          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true)
+            .asInstanceOf[ClassificationNode],
+          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, 
true)
+            .asInstanceOf[ClassificationNode],
+          split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
+      } else {
+        new RegressionInternalNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, gain = gain,
+          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false)
+            .asInstanceOf[RegressionNode],
+          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, 
false)
+            .asInstanceOf[RegressionNode],
+          split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
+      }
     }
   }
 }
 
-/**
- * Decision tree leaf node.
- * @param prediction  Prediction this node makes
- * @param impurity  Impurity measure at this node (for training data)
- */
-class LeafNode private[ml] (
-    override val prediction: Double,
-    override val impurity: Double,
-    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
+@Since("2.4.0")
+sealed trait ClassificationNode extends Node {
+
+  /**
+   * Get count of training examples for specified label in this node
+   * @param label label number in the range [0, numClasses)
+   */
+  @Since("2.4.0")
+  def getLabelCount(label: Int): Double = {
+    require(label >= 0 && label < impurityStats.stats.length,
+      "label should be in the range between 0 (inclusive) " +
+      s"and ${impurityStats.stats.length} (exclusive).")
+    impurityStats.stats(label)
+  }
+}
+
+@Since("2.4.0")
+sealed trait RegressionNode extends Node {
+
+  /** Number of training data points in this node */
+  @Since("2.4.0")
+  def getCount: Double = impurityStats.stats(0)
+
+  /** Sum over training data points of the labels in this node */
+  @Since("2.4.0")
+  def getSum: Double = impurityStats.stats(1)
+
+  /** Sum over training data points of the square of the labels in this node */
+  @Since("2.4.0")
+  def getSumOfSquares: Double = impurityStats.stats(2)
+}
+
+@Since("2.4.0")
+sealed trait LeafNode extends Node {
+
+  /** Prediction this node makes. */
+  def prediction: Double
+
+  def impurity: Double
 
   override def toString: String =
     s"LeafNode(prediction = $prediction, impurity = $impurity)"
@@ -135,32 +188,58 @@ class LeafNode private[ml] (
 
   override private[ml] def maxSplitFeatureIndex(): Int = -1
 
+}
+
+/**
+ * Decision tree leaf node for classification.
+ */
+@Since("2.4.0")
+class ClassificationLeafNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double,
+    override private[ml] val impurityStats: ImpurityCalculator)
+  extends ClassificationNode with LeafNode {
+
   override private[tree] def deepCopy(): Node = {
-    new LeafNode(prediction, impurity, impurityStats)
+    new ClassificationLeafNode(prediction, impurity, impurityStats)
   }
 }
 
 /**
- * Internal Decision Tree node.
- * @param prediction  Prediction this node would make if it were a leaf node
- * @param impurity  Impurity measure at this node (for training data)
- * @param gain Information gain value. Values less than 0 indicate missing 
values;
- *             this quirk will be removed with future updates.
- * @param leftChild  Left-hand child node
- * @param rightChild  Right-hand child node
- * @param split  Information about the test used to split to the left or right 
child.
+ * Decision tree leaf node for regression.
  */
-class InternalNode private[ml] (
+@Since("2.4.0")
+class RegressionLeafNode private[ml] (
     override val prediction: Double,
     override val impurity: Double,
-    val gain: Double,
-    val leftChild: Node,
-    val rightChild: Node,
-    val split: Split,
-    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
+    override private[ml] val impurityStats: ImpurityCalculator)
+  extends RegressionNode with LeafNode {
 
-  // Note to developers: The constructor argument impurityStats should be 
reconsidered before we
-  //                     make the constructor public.  We may be able to 
improve the representation.
+  override private[tree] def deepCopy(): Node = {
+    new RegressionLeafNode(prediction, impurity, impurityStats)
+  }
+}
+
+/**
+ * Internal Decision Tree node.
+ */
+@Since("2.4.0")
+sealed trait InternalNode extends Node {
+
+  /**
+   * Information gain value. Values less than 0 indicate missing values;
+   * this quirk will be removed with future updates.
+   */
+  def gain: Double
+
+  /** Left-hand child node */
+  def leftChild: Node
+
+  /** Right-hand child node */
+  def rightChild: Node
+
+  /** Information about the test used to split to the left or right child. */
+  def split: Split
 
   override def toString: String = {
     s"InternalNode(prediction = $prediction, impurity = $impurity, split = 
$split)"
@@ -205,11 +284,6 @@ class InternalNode private[ml] (
     math.max(split.featureIndex,
       math.max(leftChild.maxSplitFeatureIndex(), 
rightChild.maxSplitFeatureIndex()))
   }
-
-  override private[tree] def deepCopy(): Node = {
-    new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), 
rightChild.deepCopy(),
-      split, impurityStats)
-  }
 }
 
 private object InternalNode {
@@ -241,6 +315,57 @@ private object InternalNode {
 }
 
 /**
+ * Internal Decision Tree node for classification.
+ */
+@Since("2.4.0")
+class ClassificationInternalNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double,
+    override val gain: Double,
+    override val leftChild: ClassificationNode,
+    override val rightChild: ClassificationNode,
+    override val split: Split,
+    override private[ml] val impurityStats: ImpurityCalculator)
+  extends ClassificationNode with InternalNode {
+
+  // Note to developers: The constructor argument impurityStats should be 
reconsidered before we
+  //                     make the constructor public.  We may be able to 
improve the representation.
+
+  override private[tree] def deepCopy(): Node = {
+    new ClassificationInternalNode(prediction, impurity, gain,
+      leftChild.deepCopy().asInstanceOf[ClassificationNode],
+      rightChild.deepCopy().asInstanceOf[ClassificationNode],
+      split, impurityStats)
+  }
+}
+
+/**
+ * Internal Decision Tree node for regression.
+ */
+@Since("2.4.0")
+class RegressionInternalNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double,
+    override val gain: Double,
+    override val leftChild: RegressionNode,
+    override val rightChild: RegressionNode,
+    override val split: Split,
+    override private[ml] val impurityStats: ImpurityCalculator)
+  extends RegressionNode with InternalNode {
+
+  // Note to developers: The constructor argument impurityStats should be 
reconsidered before we
+  //                     make the constructor public.  We may be able to 
improve the representation.
+
+  override private[tree] def deepCopy(): Node = {
+    new RegressionInternalNode(prediction, impurity, gain,
+      leftChild.deepCopy().asInstanceOf[RegressionNode],
+      rightChild.deepCopy().asInstanceOf[RegressionNode],
+      split, impurityStats)
+  }
+}
+
+
+/**
  * Version of a node used in learning.  This uses vars so that we can modify 
nodes as we split the
  * tree by adding children, etc.
  *
@@ -265,30 +390,52 @@ private[tree] class LearningNode(
     var isLeaf: Boolean,
     var stats: ImpurityStats) extends Serializable {
 
-  def toNode: Node = toNode(prune = true)
+  def toNode(isClassification: Boolean): Node = toNode(isClassification, prune 
= true)
+
+  def toClassificationNode(prune: Boolean = true): ClassificationNode = {
+    toNode(true, prune).asInstanceOf[ClassificationNode]
+  }
+
+  def toRegressionNode(prune: Boolean = true): RegressionNode = {
+    toNode(false, prune).asInstanceOf[RegressionNode]
+  }
 
   /**
    * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any 
children.
    */
-  def toNode(prune: Boolean = true): Node = {
+  def toNode(isClassification: Boolean, prune: Boolean): Node = {
 
     if (!leftChild.isEmpty || !rightChild.isEmpty) {
       assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && 
stats != null,
         "Unknown error during Decision Tree learning.  Could not convert 
LearningNode to Node.")
-      (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
+      (leftChild.get.toNode(isClassification, prune),
+       rightChild.get.toNode(isClassification, prune)) match {
         case (l: LeafNode, r: LeafNode) if prune && l.prediction == 
r.prediction =>
-          new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+          if (isClassification) {
+            new ClassificationLeafNode(l.prediction, stats.impurity, 
stats.impurityCalculator)
+          } else {
+            new RegressionLeafNode(l.prediction, stats.impurity, 
stats.impurityCalculator)
+          }
         case (l, r) =>
-          new InternalNode(stats.impurityCalculator.predict, stats.impurity, 
stats.gain,
-            l, r, split.get, stats.impurityCalculator)
+          if (isClassification) {
+            new ClassificationInternalNode(stats.impurityCalculator.predict, 
stats.impurity,
+              stats.gain, l.asInstanceOf[ClassificationNode], 
r.asInstanceOf[ClassificationNode],
+              split.get, stats.impurityCalculator)
+          } else {
+            new RegressionInternalNode(stats.impurityCalculator.predict, 
stats.impurity, stats.gain,
+              l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode],
+              split.get, stats.impurityCalculator)
+          }
       }
     } else {
-      if (stats.valid) {
-        new LeafNode(stats.impurityCalculator.predict, stats.impurity,
+      // Here we want to keep same behavior with the old 
mllib.DecisionTreeModel
+      val impurity = if (stats.valid) stats.impurity else -1.0
+      if (isClassification) {
+        new ClassificationLeafNode(stats.impurityCalculator.predict, impurity,
           stats.impurityCalculator)
       } else {
-        // Here we want to keep same behavior with the old 
mllib.DecisionTreeModel
-        new LeafNode(stats.impurityCalculator.predict, -1.0, 
stats.impurityCalculator)
+        new RegressionLeafNode(stats.impurityCalculator.predict, impurity,
+          stats.impurityCalculator)
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 16f32d7..056a94b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -224,23 +224,23 @@ private[spark] object RandomForest extends Logging {
       case Some(uid) =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), 
numFeatures,
-              strategy.getNumClasses)
+            new DecisionTreeClassificationModel(uid, 
rootNode.toClassificationNode(prune),
+              numFeatures, strategy.getNumClasses)
           }
         } else {
           topNodes.map { rootNode =>
-            new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), 
numFeatures)
+            new DecisionTreeRegressionModel(uid, 
rootNode.toRegressionNode(prune), numFeatures)
           }
         }
       case None =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(rootNode.toNode(prune), 
numFeatures,
+            new 
DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), 
numFeatures,
               strategy.getNumClasses)
           }
         } else {
           topNodes.map(rootNode =>
-            new DecisionTreeRegressionModel(rootNode.toNode(prune), 
numFeatures))
+            new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), 
numFeatures))
         }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 4aa4c36..f027b14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -219,8 +219,10 @@ private[ml] object TreeEnsembleModel {
         importances.changeValue(feature, scaledGain, _ + scaledGain)
         computeFeatureImportance(n.leftChild, importances)
         computeFeatureImportance(n.rightChild, importances)
-      case n: LeafNode =>
+      case _: LeafNode =>
       // do nothing
+      case _ =>
+        throw new IllegalArgumentException(s"Unknown node type: 
${node.getClass.toString}")
     }
   }
 
@@ -317,6 +319,8 @@ private[ml] object DecisionTreeModelReadWrite {
         (Seq(NodeData(id, node.prediction, node.impurity, 
node.impurityStats.stats,
           -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
           id)
+      case _ =>
+        throw new IllegalArgumentException(s"Unknown node type: 
${node.getClass.toString}")
     }
   }
 
@@ -327,7 +331,7 @@ private[ml] object DecisionTreeModelReadWrite {
   def loadTreeNodes(
       path: String,
       metadata: DefaultParamsReader.Metadata,
-      sparkSession: SparkSession): Node = {
+      sparkSession: SparkSession, isClassification: Boolean): Node = {
     import sparkSession.implicits._
     implicit val format = DefaultFormats
 
@@ -339,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite {
 
     val dataPath = new Path(path, "data").toString
     val data = sparkSession.read.parquet(dataPath).as[NodeData]
-    buildTreeFromNodes(data.collect(), impurityType)
+    buildTreeFromNodes(data.collect(), impurityType, isClassification)
   }
 
   /**
@@ -348,7 +352,8 @@ private[ml] object DecisionTreeModelReadWrite {
    * @param impurityType  Impurity type for this tree
    * @return Root node of reconstructed tree
    */
-  def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
+  def buildTreeFromNodes(data: Array[NodeData], impurityType: String,
+      isClassification: Boolean): Node = {
     // Load all nodes, sorted by ID.
     val nodes = data.sortBy(_.id)
     // Sanity checks; could remove
@@ -364,10 +369,21 @@ private[ml] object DecisionTreeModelReadWrite {
       val node = if (n.leftChild != -1) {
         val leftChild = finalNodes(n.leftChild)
         val rightChild = finalNodes(n.rightChild)
-        new InternalNode(n.prediction, n.impurity, n.gain, leftChild, 
rightChild,
-          n.split.getSplit, impurityStats)
+        if (isClassification) {
+          new ClassificationInternalNode(n.prediction, n.impurity, n.gain,
+            leftChild.asInstanceOf[ClassificationNode], 
rightChild.asInstanceOf[ClassificationNode],
+            n.split.getSplit, impurityStats)
+        } else {
+          new RegressionInternalNode(n.prediction, n.impurity, n.gain,
+            leftChild.asInstanceOf[RegressionNode], 
rightChild.asInstanceOf[RegressionNode],
+            n.split.getSplit, impurityStats)
+        }
       } else {
-        new LeafNode(n.prediction, n.impurity, impurityStats)
+        if (isClassification) {
+          new ClassificationLeafNode(n.prediction, n.impurity, impurityStats)
+        } else {
+          new RegressionLeafNode(n.prediction, n.impurity, impurityStats)
+        }
       }
       finalNodes(n.id) = node
     }
@@ -421,7 +437,8 @@ private[ml] object EnsembleModelReadWrite {
       path: String,
       sql: SparkSession,
       className: String,
-      treeClassName: String): (Metadata, Array[(Metadata, Node)], 
Array[Double]) = {
+      treeClassName: String,
+      isClassification: Boolean): (Metadata, Array[(Metadata, Node)], 
Array[Double]) = {
     import sql.implicits._
     implicit val format = DefaultFormats
     val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, 
className)
@@ -449,7 +466,8 @@ private[ml] object EnsembleModelReadWrite {
     val rootNodesRDD: RDD[(Int, Node)] =
       nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
         case (treeID: Int, nodeData: Iterable[NodeData]) =>
-          treeID -> 
DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(
+            nodeData.toArray, impurityType, isClassification)
       }
     val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
     (metadata, treesMetadata.zip(rootNodes), treesWeights)

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 2930f49..d3dbb4e 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.ClassificationLeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -61,7 +61,8 @@ class DecisionTreeClassifierSuite extends MLTest with 
DefaultReadWriteTest {
 
   test("params") {
     ParamsSuite.checkParams(new DecisionTreeClassifier)
-    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 
0.0, null), 1, 2)
+    val model = new DecisionTreeClassificationModel("dtc",
+      new ClassificationLeafNode(0.0, 0.0, null), 1, 2)
     ParamsSuite.checkParams(model)
   }
 
@@ -375,6 +376,32 @@ class DecisionTreeClassifierSuite extends MLTest with 
DefaultReadWriteTest {
 
     testDefaultReadWrite(model)
   }
+
+  test("label/impurity stats") {
+    val arr = Array(
+      LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+      LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+      LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+    val rdd = sc.parallelize(arr)
+    val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2)
+    val dt1 = new DecisionTreeClassifier()
+      .setImpurity("entropy")
+      .setMaxDepth(2)
+      .setMinInstancesPerNode(2)
+    val model1 = dt1.fit(df)
+
+    val rootNode1 = model1.rootNode
+    assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === 
Array(2.0, 1.0))
+
+    val dt2 = new DecisionTreeClassifier()
+      .setImpurity("gini")
+      .setMaxDepth(2)
+      .setMinInstancesPerNode(2)
+    val model2 = dt2.fit(df)
+
+    val rootNode2 = model2.rootNode
+    assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === 
Array(2.0, 1.0))
+  }
 }
 
 private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 5779606..f0ee549 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.RegressionLeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
@@ -69,7 +69,7 @@ class GBTClassifierSuite extends MLTest with 
DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new GBTClassifier)
     val model = new GBTClassificationModel("gbtc",
-      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, 
null), 1)),
+      Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 
0.0, null), 1)),
       Array(1.0), 1, 2)
     ParamsSuite.checkParams(model)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ba4a9cf..3062aa9 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.ClassificationLeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -71,7 +71,8 @@ class RandomForestClassifierSuite extends MLTest with 
DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new RandomForestClassifier)
     val model = new RandomForestClassificationModel("rfc",
-      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, 
null), 1, 2)), 2, 2)
+      Array(new DecisionTreeClassificationModel("dtc",
+        new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
     ParamsSuite.checkParams(model)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 29a4383..9ae2733 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -191,6 +191,20 @@ class DecisionTreeRegressorSuite extends MLTest with 
DefaultReadWriteTest {
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
+
+  test("label/impurity stats") {
+    val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+    val df = TreeTests.setMetadata(categoricalDataPointsRDD, 
categoricalFeatures, numClasses = 0)
+    val dtr = new DecisionTreeRegressor()
+      .setImpurity("variance")
+      .setMaxDepth(2)
+      .setMaxBins(8)
+    val model = dtr.fit(df)
+    val statInfo = model.rootNode
+
+    assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0
+      && statInfo.getSumOfSquares == 600.0)
+  }
 }
 
 private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 743dacf..4dbbd75 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(topNode.stats.impurity > 0.0)
 
     // set impurity and predict for child nodes
-    assert(topNode.leftChild.get.toNode.prediction === 0.0)
-    assert(topNode.rightChild.get.toNode.prediction === 1.0)
+    assert(topNode.leftChild.get.toNode(isClassification = true).prediction 
=== 0.0)
+    assert(topNode.rightChild.get.toNode(isClassification = true).prediction 
=== 1.0)
     assert(topNode.leftChild.get.stats.impurity === 0.0)
     assert(topNode.rightChild.get.stats.impurity === 0.0)
   }
@@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(topNode.stats.impurity > 0.0)
 
     // set impurity and predict for child nodes
-    assert(topNode.leftChild.get.toNode.prediction === 0.0)
-    assert(topNode.rightChild.get.toNode.prediction === 1.0)
+    assert(topNode.leftChild.get.toNode(isClassification = true).prediction 
=== 0.0)
+    assert(topNode.rightChild.get.toNode(isClassification = true).prediction 
=== 1.0)
     assert(topNode.leftChild.get.stats.impurity === 0.0)
     assert(topNode.rightChild.get.stats.impurity === 0.0)
   }
@@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
                 left  right
      */
     val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
-    val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
+    val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp)
 
     val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
-    val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
+    val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp)
 
-    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 
0.5))
+    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 
0.5), true)
     val parentImp = parent.impurityStats
 
     val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
-    val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
+    val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp)
 
-    val grandParent = TreeTests.buildParentNode(left2, parent, new 
ContinuousSplit(1, 1.0))
+    val grandParent = TreeTests.buildParentNode(left2, parent, new 
ContinuousSplit(1, 1.0), true)
     val grandImp = grandParent.impurityStats
 
     // Test feature importance computed at different subtrees.
@@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 
     // Forest consisting of (full tree) + (internal node with 2 leafs)
     val trees = Array(parent, grandParent).map { root =>
-      new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 
3)
-        .asInstanceOf[DecisionTreeModel]
+      new 
DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode],
+        numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel]
     }
     val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
     val tree2norm = feature0importance + feature1importance

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index b6894b3..3f03d90 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite {
    * @param split  Split for parent node
    * @return  Parent node with children attached
    */
-  def buildParentNode(left: Node, right: Node, split: Split): Node = {
+  def buildParentNode(left: Node, right: Node, split: Split, isClassification: 
Boolean): Node = {
     val leftImp = left.impurityStats
     val rightImp = right.impurityStats
     val parentImp = leftImp.copy.add(rightImp)
@@ -168,7 +168,15 @@ private[ml] object TreeTests extends SparkFunSuite {
     val gain = parentImp.calculate() -
       (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
     val pred = parentImp.predict
-    new InternalNode(pred, parentImp.calculate(), gain, left, right, split, 
parentImp)
+    if (isClassification) {
+      new ClassificationInternalNode(pred, parentImp.calculate(), gain,
+        left.asInstanceOf[ClassificationNode], 
right.asInstanceOf[ClassificationNode],
+        split, parentImp)
+    } else {
+      new RegressionInternalNode(pred, parentImp.calculate(), gain,
+        left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode],
+        split, parentImp)
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/252468a7/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 1b6d1de..b37b4d5 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -55,7 +55,14 @@ object MimaExcludes {
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"),
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"),
     
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"),
-    
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel")
+    
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"),
+
+    // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision 
tree nodes
+    
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"),
+    
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"),
+    
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"),
+    
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"),
+    
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this")
   )
 
   // Exclude rules for 2.3.x


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to