Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20632#discussion_r170738944
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
    @@ -631,10 +634,99 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
         val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
         assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
       }
    +
    +  ///////////////////////////////////////////////////////////////////////////////
    +  // Tests for pruning of redundant subtrees (generated by a split improving the
    +  // impurity measure, but always leading to the same prediction).
    +  ///////////////////////////////////////////////////////////////////////////////
    +
    +  test("[SPARK-3159] tree model redundancy - binary classification") {
    +    // The following dataset is set up such that splitting over feature 1 for points having
    +    // feature 0 = 0 improves the impurity measure, despite the prediction will always be 0
    +    // in both branches.
    +    val arr = Array(
    +      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
    +      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
    +      LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
    +      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
    +      LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
    +      LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
    +    )
    +    val rdd = sc.parallelize(arr)
    +
    +    val numClasses = 2
    +    val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 4,
    +      numClasses = numClasses, maxBins = 32)
    +
    +    val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
    +      seed = 42, instr = None).head
    +
    +    val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, 
featureSubsetStrategy = "auto",
    +      seed = 42, instr = None, prune = false).head
    +
    +    assert(prunedTree.numNodes === 5)
    +    assert(unprunedTree.numNodes === 7)
    +  }
    +
    +  test("[SPARK-3159] tree model redundancy - multiclass classification") {
    --- End diff --
    
    Why do we need to test binary and multiclass separately?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to