Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/20632#discussion_r170412098 --- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala --- @@ -631,6 +651,160 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } + + test("[SPARK-3159] tree model redundancy - binary classification") { + val numClasses = 2 + + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy) + + /* Expected tree structure tested below: + root + left1 right1 + left2 right2 + + pred(left1) = 0 + pred(left2) = 1 + pred(right2) = 0 + */ + assert(dt.rootNode.numDescendants === 4) + assert(dt.rootNode.subtreeDepth === 2) + + assert(dt.rootNode.isInstanceOf[InternalNode]) + + // left 1 prediction test + assert(dt.rootNode.asInstanceOf[InternalNode].leftChild.prediction === 0) + + val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild + assert(right1.isInstanceOf[InternalNode]) + + // left 2 prediction test + assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 1) + // right 2 prediction test + assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 0) + } + + test("[SPARK-3159] tree model redundancy - multiclass classification") { + val numClasses = 4 + + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy) + + /* Expected tree structure tested below: + root + left1 right1 + left2 right2 left3 right3 + + pred(left2) = 0 + pred(right2) = 1 + pred(left3) = 2 + pred(right3) = 1 + */ + assert(dt.rootNode.numDescendants === 6) + 
assert(dt.rootNode.subtreeDepth === 2) + + assert(dt.rootNode.isInstanceOf[InternalNode]) + + val left1 = dt.rootNode.asInstanceOf[InternalNode].leftChild + val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild + + assert(left1.isInstanceOf[InternalNode]) + + // left 2 prediction test + assert(left1.asInstanceOf[InternalNode].leftChild.prediction === 0) + // right 2 prediction test + assert(left1.asInstanceOf[InternalNode].rightChild.prediction === 1) + + assert(right1.isInstanceOf[InternalNode]) + + // left 3 prediction test + assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 2) + // right 3 prediction test + assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 1) + } + + test("[SPARK-3159] tree model redundancy - regression") { + val numClasses = 2 + + val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, + maxDepth = 3, maxBins = 10, numClasses = numClasses) + + val dt = buildRedundantDecisionTree(numClasses, 20, strategy = strategy) + + /* Expected tree structure tested below: + root + 1 2 + 1_1 1_2 2_1 2_2 + 1_1_1 1_1_2 1_2_1 1_2_2 2_1_1 2_1_2 + + pred(1_1_1) = 0.5 + pred(1_1_2) = 0.0 + pred(1_2_1) = 0.0 + pred(1_2_2) = 0.25 + pred(2_1_1) = 1.0 + pred(2_1_2) = 0.6666666666666666 + pred(2_2) = 0.5 + */ + + assert(dt.rootNode.numDescendants === 12) --- End diff -- The tree tests are already so long and complicated that I think it's important to simplify where possible. These tests are useful as they are, but it probably won't be obvious why/how they work to future devs. Also, if we can avoid adding data generation code, that would be nice (there's already tons of code like that laying around the test suites).
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org