Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20632#discussion_r170412098
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
    @@ -631,6 +651,160 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
         val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
         assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
       }
    +
    +  test("[SPARK-3159] tree model redundancy - binary classification") {
    +    val numClasses = 2
    +
    +    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
    +      numClasses = numClasses, maxBins = 32)
    +
    +    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
    +
    +    /* Expected tree structure tested below:
    +             root
    +      left1       right1
    +                left2  right2
    +
    +      pred(left1)  = 0
    +      pred(left2)  = 1
    +      pred(right2) = 0
    +     */
    +    assert(dt.rootNode.numDescendants === 4)
    +    assert(dt.rootNode.subtreeDepth === 2)
    +
    +    assert(dt.rootNode.isInstanceOf[InternalNode])
    +
    +    // left 1 prediction test
    +    assert(dt.rootNode.asInstanceOf[InternalNode].leftChild.prediction === 
0)
    +
    +    val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
    +    assert(right1.isInstanceOf[InternalNode])
    +
    +    // left 2 prediction test
    +    assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 1)
    +    // right 2 prediction test
    +    assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 0)
    +  }
    +
    +  test("[SPARK-3159] tree model redundancy - multiclass classification") {
    +    val numClasses = 4
    +
    +    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
    +      numClasses = numClasses, maxBins = 32)
    +
    +    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
    +
    +    /* Expected tree structure tested below:
    +                    root
    +            left1         right1
    +        left2  right2  left3  right3
    +
    +      pred(left2)  = 0
    +      pred(right2)  = 1
    +      pred(left3) = 2
    +      pred(right3) = 1
    +     */
    +    assert(dt.rootNode.numDescendants === 6)
    +    assert(dt.rootNode.subtreeDepth === 2)
    +
    +    assert(dt.rootNode.isInstanceOf[InternalNode])
    +
    +    val left1 = dt.rootNode.asInstanceOf[InternalNode].leftChild
    +    val right1 = dt.rootNode.asInstanceOf[InternalNode].rightChild
    +
    +    assert(left1.isInstanceOf[InternalNode])
    +
    +    // left 2 prediction test
    +    assert(left1.asInstanceOf[InternalNode].leftChild.prediction === 0)
    +    // right 2 prediction test
    +    assert(left1.asInstanceOf[InternalNode].rightChild.prediction === 1)
    +
    +    assert(right1.isInstanceOf[InternalNode])
    +
    +    // left 3 prediction test
    +    assert(right1.asInstanceOf[InternalNode].leftChild.prediction === 2)
    +    // right 3 prediction test
    +    assert(right1.asInstanceOf[InternalNode].rightChild.prediction === 1)
    +  }
    +
    +  test("[SPARK-3159] tree model redundancy - regression") {
    +    val numClasses = 2
    +
    +    val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = 
Variance,
    +      maxDepth = 3, maxBins = 10, numClasses = numClasses)
    +
    +    val dt = buildRedundantDecisionTree(numClasses, 20, strategy = 
strategy)
    +
    +    /* Expected tree structure tested below:
    +                                root
    +                    1                             2
    +           1_1             1_2             2_1         2_2
    +      1_1_1   1_1_2   1_2_1   1_2_2   2_1_1   2_1_2
    +
    +      pred(1_1_1)  = 0.5
    +      pred(1_1_2)  = 0.0
    +      pred(1_2_1)  = 0.0
    +      pred(1_2_2)  = 0.25
    +      pred(2_1_1)  = 1.0
    +      pred(2_1_2)  = 0.6666666666666666
    +      pred(2_2)    = 0.5
    +     */
    +
    +    assert(dt.rootNode.numDescendants === 12)
    --- End diff --
    
    The tree tests are already so long and complicated that I think it's 
important to simplify where possible. These tests are useful as they are, but 
it probably won't be obvious to future devs why/how they work. Also, if we can 
avoid adding data generation code, that would be nice (there is already a lot of 
code like that lying around the test suites). 


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to