Github user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19666#discussion_r149274660
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
    @@ -976,6 +930,44 @@ private[spark] object RandomForest extends Logging {
         categories
       }
     
    +  private[tree] def traverseUnorderedSplits[T](
    +      arity: Int,
    +      zeroStats: T,
    +      seqOp: (T, Int) => T,
    +      finalizer: (BitSet, T) => Unit): Unit = {
    +    assert(arity > 1)
    +
    +    // numSplits = (1 << (arity - 1)) - 1
    +    val numSplits = DecisionTreeMetadata.numUnorderedSplits(arity)
    +    val subSet: BitSet = new BitSet(arity)
    +
    +    // dfs traverse
    +    // binIndex: [0, arity)
    +    def dfs(binIndex: Int, combNumber: Int, stats: T): Unit = {
    +      if (binIndex == arity) {
    +        // recursion exit when binIndex == arity
    +        if (combNumber > 0) {
    +          // we get an available unordered split, saved in subSet.
    +          finalizer(subSet, stats)
    +        }
    +      } else {
    +        subSet.set(binIndex)
    +        val leftChildCombNumber = combNumber + (1 << binIndex)
    +        // pruning: combNumber only needs to satisfy 1 <= combNumber <= numSplits
    --- End diff --
    
    Yes. For example, "00101" and "11010" are equivalent splits, so we should traverse only one of them.
    Here I use the condition `1 <= combNumber <= numSplits` for pruning, which filters out the other half of the splits.
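    To make the pruning argument concrete, here is a minimal standalone sketch (not the Spark implementation; the object and method names are hypothetical). For a categorical feature with `arity` bins, each subset bitmask `c` and its complement describe the same left/right split, so keeping only `1 <= combNumber <= numSplits` with `numSplits = (1 << (arity - 1)) - 1` traverses exactly one mask per equivalent pair:

```scala
// Hypothetical illustration of the complement-based pruning, not Spark code.
object UnorderedSplitPruning {

  // Masks kept by the pruning condition: 1 <= combNumber <= numSplits,
  // where numSplits = (1 << (arity - 1)) - 1 is half of the
  // 2^arity - 2 nonempty proper subsets of the bins.
  def enumerate(arity: Int): Seq[Int] = {
    require(arity > 1)
    val numSplits = (1 << (arity - 1)) - 1
    1 to numSplits
  }

  def main(args: Array[String]): Unit = {
    val arity = 3
    val full = (1 << arity) - 1  // bitmask with all bins set
    val kept = enumerate(arity)
    // No kept mask's complement is also kept, so each pair of
    // equivalent splits (e.g. "001" and "110") appears exactly once.
    assert(kept.forall(c => !kept.contains(full ^ c)))
    // Binary masks kept for arity = 3: "1", "10", "11"
    println(kept.map(_.toBinaryString))
  }
}
```

    For `arity = 3` this keeps masks 1, 2, 3 and drops their complements 6, 5, 4, matching the "filter out the other half" behavior described above.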

