Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/14872#discussion_r76853239
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala ---
    @@ -0,0 +1,575 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import org.roaringbitmap.RoaringBitmap
    +
    +import org.apache.spark.ml.tree._
    +import org.apache.spark.mllib.tree.model.ImpurityStats
    +import org.apache.spark.util.collection.BitSet
    +
    +/** Object exposing methods for local training of decision trees */
    +private[ml] object LocalDecisionTree {
    +
    +  /**
    +   * Fully splits the passed-in node on the provided local dataset.
    +   * TODO(smurching): Accept a seed for feature subsampling
    +   *
    +   * @param node LearningNode to split
    +   */
    +  def fitNode(
    +      input: Array[BaggedPoint[TreePoint]],
    +      node: LearningNode,
    +      metadata: DecisionTreeMetadata,
    +      splits: Array[Array[Split]]): Node = {
    +
    +    // The case with 1 node (depth = 0) is handled separately.
    +    // This allows all iterations in the depth > 0 case to use the same code.
    +    // TODO: Check that learning works when maxDepth > 0 but learning stops at 1 node (because of
    +    //       other parameters).
    +    if (metadata.maxDepth == 0) {
    +      // Depth-0 tree: fold every label into one aggregator and emit a single leaf.
    +      val impurityAggregator: ImpurityAggregatorSingle =
    +        input.aggregate(metadata.createImpurityAggregator())(
    +          (agg, lp) => agg.update(lp.datum.label, 1.0),
    +          (agg1, agg2) => agg1.add(agg2))
    +      val impurityCalculator = impurityAggregator.getCalculator
    +      return new LeafNode(LocalDecisionTreeUtils.getPredict(impurityCalculator).predict,
    +        impurityCalculator.calculate(), impurityCalculator)
    +    }
    +
    +    // Prepare column store.
    +    //   Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue.
    +    val colStoreInit: Array[Array[Int]] =
    +      LocalDecisionTreeUtils.rowToColumnStoreDense(input.map(_.datum.binnedFeatures))
    +    // Labels are extracted in the same row order as `input`.
    +    val labels = input.map(_.datum.label)
    +
    +    // Local classification (numClasses in 2..32) is not supported yet and is rejected here;
    +    // every other numClasses value is fit as a regression model on the dataset.
    +    if (metadata.numClasses > 1 && metadata.numClasses <= 32) {
    +      // Note the trailing space after "is": without it the two literals run together.
    +      throw new UnsupportedOperationException("Local training of a decision tree classifier is " +
    +        "unsupported; currently, only regression is supported")
    +    } else {
    +      // TODO(smurching): Pass an array of instanceWeights extracted from the input BaggedPoint?
    +      // Also, pass seed for feature subsampling
    +      trainRegressor(node, colStoreInit, labels, metadata, splits)
    +    }
    +  }
    +
    +  /**
    +   * Locally fits a decision tree regressor.
    +   * TODO(smurching): Logic for fitting a classifier & regressor seems the 
same; only difference
    +   * is impurity metric. Use the same logic for fitting a classifier?
    +   *
    +   * @param rootNode Node to fit on the passed-in dataset
    +   * @param colStoreInit Array of columns of training data
    +   * @param metadata Metadata object
    +   * @param splits splits(i) = Array of possible splits for feature i
    +   * @return
    +   */
    +  def trainRegressor(
    +      rootNode: LearningNode,
    +      colStoreInit: Array[Array[Int]],
    +      labels: Array[Double],
    +      metadata: DecisionTreeMetadata,
    +      splits: Array[Array[Split]]): Node = {
    +
    +    // Sort each column by feature values.
    +    val colStore: Array[FeatureVector] = colStoreInit.zipWithIndex.map { 
case (col, featureIndex) =>
    +      val featureArity: Int = 
metadata.featureArity.getOrElse(featureIndex, 0)
    +      FeatureVector.fromOriginal(featureIndex, featureArity, col)
    +    }
    +
    +    val numRows = colStore.headOption match {
    +      case None => 0
    +      case Some(column) => column.values.size
    +    }
    +
    +    // Create an impurityAggregator object containing info for 1 node (the 
root node).
    +    val fullImpurityAgg = metadata.createImpurityAggregator()
    +    labels.foreach(fullImpurityAgg.update(_))
    +
    +    // Create a bitset describing the set of active (non-leaf) nodes; 
initially, only the
    +    // root node is active
    +    val initActive = new BitSet(1)
    +    initActive.set(0)
    +    var partitionInfo: PartitionInfo = new PartitionInfo(colStore,
    +      Array[Int](0, numRows), initActive, Array(fullImpurityAgg))
    +
    +    // Initialize model.
    +    // Note: We do not use node indices.
    +    // Active nodes (still being split), updated each iteration
    +    var activeNodePeriphery: Array[LearningNode] = Array(rootNode)
    +    var numNodeOffsets: Int = 2
    +
    +    // Iteratively learn, one level of the tree at a time.
    +    var currentLevel = 0
    +    var doneLearning = false
    +    while (currentLevel < metadata.maxDepth && !doneLearning) {
    +      // Compute best split for each active node.
    +      val bestSplitsAndGains: Array[(Option[Split], ImpurityStats)] =
    +        computeBestSplits(partitionInfo, labels, metadata, splits)
    +      /*
    +      // NOTE: The actual active nodes (activeNodePeriphery) may be a 
subset of the nodes under
    +      //       bestSplitsAndGains since
    +      assert(activeNodePeriphery.length == bestSplitsAndGains.length,
    +        s"activeNodePeriphery.length=${activeNodePeriphery.length} does 
not equal" +
    +          s" bestSplitsAndGains.length=${bestSplitsAndGains.length}")
    +      */
    +
    +      // Update current model and node periphery.
    +      // Note: This flatMap has side effects (on the model).
    +      activeNodePeriphery = 
LocalDecisionTreeUtils.computeActiveNodePeriphery(activeNodePeriphery,
    --- End diff --
    
    Just get the children from computeBestSplits.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to