Github user srowen commented on a diff in the pull request: https://github.com/apache/spark/pull/20632#discussion_r169092238 --- Diff: mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala --- @@ -402,20 +406,40 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(2.0))) val input = sc.parallelize(arr) + val seed = 42 + val numTrees = 1 + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) - val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head - model.rootNode match { - case n: InternalNode => n.split match { - case s: CategoricalSplit => - assert(s.leftCategories === Array(1.0)) - case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit") - } - case _ => throw new AssertionError("model.rootNode was not an InternalNode") - } + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy, numTrees = numTrees, + featureSubsetStrategy = "all") + val splits = RandomForest.findSplits(input, metadata, seed = seed) + + val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, + strategy.subsamplingRate, numTrees, false, seed = seed) + + val topNode = LearningNode.emptyNode(nodeIndex = 1) + assert(topNode.isLeaf === false) + assert(topNode.stats === null) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( --- End diff -- Nit: I think `Map(a -> b)` syntax is clearer than `Map((a, b))` syntax. (You use the former a few lines later too.)
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org