Github user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19666#discussion_r149274660

--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
@@ -976,6 +930,44 @@ private[spark] object RandomForest extends Logging {
     categories
   }
 
+  private[tree] def traverseUnorderedSplits[T](
+      arity: Int,
+      zeroStats: T,
+      seqOp: (T, Int) => T,
+      finalizer: (BitSet, T) => Unit): Unit = {
+    assert(arity > 1)
+
+    // numSplits = (1 << (arity - 1)) - 1
+    val numSplits = DecisionTreeMetadata.numUnorderedSplits(arity)
+    val subSet: BitSet = new BitSet(arity)
+
+    // DFS traversal
+    // binIndex: [0, arity)
+    def dfs(binIndex: Int, combNumber: Int, stats: T): Unit = {
+      if (binIndex == arity) {
+        // recursion exits when binIndex == arity
+        if (combNumber > 0) {
+          // we have found a valid unordered split, saved in subSet
+          finalizer(subSet, stats)
+        }
+      } else {
+        subSet.set(binIndex)
+        val leftChildCombNumber = combNumber + (1 << binIndex)
+        // pruning: only traverse combNumber values satisfying 1 <= combNumber <= numSplits
--- End diff --

Yes. For example, "00101" and "11010" are equivalent splits, so we should traverse only one of them. Here I use the condition `1 <= combNumber <= numSplits` to do the pruning; it simply filters out the other half of the splits.
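To illustrate the pruning being discussed, here is a simplified, self-contained sketch of the DFS traversal (not the actual PR code: it uses a plain `ArrayBuffer` instead of Spark's `BitSet`/stats plumbing, and the names `traverse` and `withBin` are invented for this example). Each subset of the `arity` categories is encoded as a bitmask `combNumber`; bounding it by `numSplits = 2^(arity-1) - 1` keeps exactly one member of each complementary pair, e.g. "00101" but not "11010":

```scala
object UnorderedSplits {
  // Enumerate one representative bitmask per unordered split of `arity` categories.
  def traverse(arity: Int): Seq[Int] = {
    require(arity > 1)
    // Number of distinct unordered splits: 2^(arity - 1) - 1.
    val numSplits = (1 << (arity - 1)) - 1
    val result = scala.collection.mutable.ArrayBuffer.empty[Int]

    // DFS over bins [0, arity); combNumber is the bitmask of bins assigned
    // to the left side of the split.
    def dfs(binIndex: Int, combNumber: Int): Unit = {
      if (binIndex == arity) {
        // Skip the empty subset; every other reachable mask is a valid split.
        if (combNumber > 0) result += combNumber
      } else {
        // Branch 1: include binIndex on the left side, but prune any mask
        // exceeding numSplits -- its complement will be reached instead.
        val withBin = combNumber + (1 << binIndex)
        if (withBin <= numSplits) dfs(binIndex + 1, withBin)
        // Branch 2: leave binIndex on the right side.
        dfs(binIndex + 1, combNumber)
      }
    }
    dfs(0, 0)
    result.toSeq
  }
}
```

For `arity = 3` this yields the masks {1, 2, 3} (binary 001, 010, 011) and prunes their complements {6, 5, 4}, so exactly half of the non-trivial subsets survive.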