Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/14872#discussion_r76853370
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala ---
    @@ -0,0 +1,575 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import org.roaringbitmap.RoaringBitmap
    +
    +import org.apache.spark.ml.tree._
    +import org.apache.spark.mllib.tree.model.ImpurityStats
    +import org.apache.spark.util.collection.BitSet
    +
    +/** Object exposing methods for local training of decision trees */
    +private[ml] object LocalDecisionTree {
    +
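    +  /**
    +   * Fully splits the passed-in node on the provided local dataset.
    +   * TODO(smurching): Accept a seed for feature subsampling
    +   *
    +   * @param input Local training data
    +   * @param node LearningNode to split
    +   * @param metadata Metadata object
    +   * @param splits splits(i) = Array of possible splits for feature i
    +   * @return The fitted tree rooted at `node`; a single LeafNode when metadata.maxDepth == 0
    +   * @throws UnsupportedOperationException if metadata describes a classification problem
    +   *         (1 < numClasses <= 32) and maxDepth > 0; only regression is supported locally
    +   */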
    +  def fitNode(
    +      input: Array[BaggedPoint[TreePoint]],
    +      node: LearningNode,
    +      metadata: DecisionTreeMetadata,
    +      splits: Array[Array[Split]]): Node = {
    +
    +    if (metadata.maxDepth == 0) {
    +      // The case with 1 node (depth = 0) is handled separately, so that every iteration of
    +      // the depth > 0 loop in trainRegressor can use the same code.
    +      // TODO: Check that learning works when maxDepth > 0 but learning stops at 1 node
    +      //       (because of other parameters).
    +      // NOTE(review): every instance is counted with weight 1.0 here; any bagging weights
    +      //       carried by BaggedPoint are not applied.
    +      val impurityAggregator: ImpurityAggregatorSingle =
    +        input.aggregate(metadata.createImpurityAggregator())(
    +          (agg, lp) => agg.update(lp.datum.label, 1.0),
    +          (agg1, agg2) => agg1.add(agg2))
    +      val impurityCalculator = impurityAggregator.getCalculator
    +      new LeafNode(LocalDecisionTreeUtils.getPredict(impurityCalculator).predict,
    +        impurityCalculator.calculate(), impurityCalculator)
    +    } else if (metadata.numClasses > 1 && metadata.numClasses <= 32) {
    +      // 1 < numClasses <= 32 means classification, which cannot yet be trained locally.
    +      // Fail before building the column store so no work is wasted.
    +      throw new UnsupportedOperationException("Local training of a decision tree classifier " +
    +        "is unsupported; currently, only regression is supported")
    +    } else {
    +      // Any other numClasses is fit as a regression model.
    +      // Prepare the column store (one column per feature).
    +      //   Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue.
    +      val colStoreInit: Array[Array[Int]] =
    +        LocalDecisionTreeUtils.rowToColumnStoreDense(input.map(_.datum.binnedFeatures))
    +      val labels = input.map(_.datum.label)
    +
    +      // TODO(smurching): Pass an array of instanceWeights extracted from the input BaggedPoint?
    +      // Also, pass seed for feature subsampling
    +      trainRegressor(node, colStoreInit, labels, metadata, splits)
    +    }
    +  }
    +
    +  /**
    +   * Locally fits a decision tree regressor.
    +   * TODO(smurching): Logic for fitting a classifier & regressor seems the same; the only
    +   * difference is the impurity metric. Use the same logic for fitting a classifier?
    +   *
    +   * @param rootNode Node to fit on the passed-in dataset
    +   * @param colStoreInit Array of columns of training data; colStoreInit(i) holds feature i
    +   * @param labels Label of each training instance
    +   * @param metadata Metadata object
    +   * @param splits splits(i) = Array of possible splits for feature i
    +   * @return The learned tree, i.e. rootNode converted via toNode after splitting
    +   */
    +  def trainRegressor(
    +      rootNode: LearningNode,
    +      colStoreInit: Array[Array[Int]],
    +      labels: Array[Double],
    +      metadata: DecisionTreeMetadata,
    +      splits: Array[Array[Split]]): Node = {
    +
    +    // Sort each column by feature values.
    +    val colStore: Array[FeatureVector] = colStoreInit.zipWithIndex.map {
    +      case (col, featureIndex) =>
    +        val featureArity: Int = metadata.featureArity.getOrElse(featureIndex, 0)
    +        FeatureVector.fromOriginal(featureIndex, featureArity, col)
    +    }
    +
    +    // Number of training rows. This is taken from labels rather than from the first column so
    +    // that it stays consistent with fullImpurityAgg (built from all labels) even when there are
    +    // no feature columns. NOTE(review): assumes every column has labels.length entries.
    +    val numRows = labels.length
    +
    +    // Create an impurityAggregator object containing info for 1 node (the root node).
    +    val fullImpurityAgg = metadata.createImpurityAggregator()
    +    labels.foreach(fullImpurityAgg.update(_))
    +
    +    // Bitset describing the set of active (non-leaf) nodes; initially, only the root is active.
    +    val initActive = new BitSet(1)
    +    initActive.set(0)
    +    var partitionInfo: PartitionInfo = new PartitionInfo(colStore,
    +      Array[Int](0, numRows), initActive, Array(fullImpurityAgg))
    +
    +    // Initialize model.
    +    // Note: We do not use node indices.
    +    // Active nodes (still being split), updated each iteration.
    +    var activeNodePeriphery: Array[LearningNode] = Array(rootNode)
    +    // Presumably the length of partitionInfo's nodeOffsets array (initially [0, numRows]).
    +    var numNodeOffsets: Int = 2
    +
    +    // Iteratively learn, one level of the tree at a time.
    +    var currentLevel = 0
    +    var doneLearning = false
    +    while (currentLevel < metadata.maxDepth && !doneLearning) {
    +      // Compute best split for each active node.
    +      // NOTE(review): activeNodePeriphery may be a subset of the nodes covered by
    +      //       bestSplitsAndGains, so the two lengths are not asserted to be equal.
    +      val bestSplitsAndGains: Array[(Option[Split], ImpurityStats)] =
    +        computeBestSplits(partitionInfo, labels, metadata, splits)
    +
    +      // Update current model and node periphery.
    +      // Note: computeActiveNodePeriphery has side effects on the model (the LearningNodes
    +      //       under rootNode); the final rootNode.toNode relies on them.
    +      activeNodePeriphery = LocalDecisionTreeUtils.computeActiveNodePeriphery(
    +        activeNodePeriphery, bestSplitsAndGains, metadata.minInfoGain,
    +        metadata.minInstancesPerNode)
    +      // We keep all old nodeOffsets and add one for each node split.
    +      // Each node split adds 2 nodes to activeNodePeriphery.
    +      // TODO: Should this be calculated after filtering for impurity??
    +      numNodeOffsets = numNodeOffsets + activeNodePeriphery.length / 2
    +
    +      // Count the periphery nodes that could still be split (non-zero impurity).
    +      val estimatedRemainingActive = activeNodePeriphery.count(_.stats.impurity > 0.0)
    +
    +      // Stop once the next level would exceed maxDepth or no node can be split further.
    +      // TODO: Check to make sure we split something, and stop otherwise.
    +      doneLearning = currentLevel + 1 >= metadata.maxDepth || estimatedRemainingActive == 0
    +
    +      if (!doneLearning) {
    +        val bestSplits: Array[Option[Split]] = bestSplitsAndGains.map(_._1)
    +
    +        // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right
    +        val aggBitVector: RoaringBitmap = LocalDecisionTreeUtils.aggregateBitVector(
    +          partitionInfo, bestSplits, numRows, splits)
    +
    +        // Copy the RoaringBitmap into the BitSet type that PartitionInfo.update takes.
    +        val bv = new BitSet(numRows)
    +        val iter = aggBitVector.getIntIterator
    +        while (iter.hasNext) {
    +          bv.set(iter.next)
    +        }
    +
    +        // Obtain a new partitionInfo instance describing our current training status,
    +        // including the offsets of each node in our arrays of columns.
    +        partitionInfo = partitionInfo.update(bv, numNodeOffsets, labels, metadata)
    +      }
    +      currentLevel += 1
    +    }
    +
    +    // Done with learning
    +    rootNode.toNode
    +  }
    +
    +  /**
    +   * Find the best splits for all active nodes.
    +   *  - For each feature, select the best split for each node.
    +   *
    +   * @return  Array over active nodes of (best split, impurity stats for 
split),
    +   *          where the split is None if no useful split exists
    +   */
    +  private[impl] def computeBestSplits(
    +      partitionInfo: PartitionInfo,
    +      labels: Array[Double],
    +      metadata: DecisionTreeMetadata,
    +      splits: Array[Array[Split]]): Array[(Option[Split], ImpurityStats)] 
= {
    +    // For each feature, select the best split for each node.
    +    // This will use:
    +    //  - labels (the labels column)
    +    // Returns:
    +    //   for each active node, best split + info gain,
    +    //     where the best split is None if no useful split exists
    +
    +    partitionInfo match {
    +      case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: 
Array[Int],
    +        activeNodes: BitSet, fullImpurityAggs: 
Array[ImpurityAggregatorSingle]) => {
    +
    +        // Iterate over the active nodes in the current level.
    +        val toReturn = new Array[(Option[Split], 
ImpurityStats)](activeNodes.cardinality())
    +        val iter: Iterator[Int] = activeNodes.iterator
    +        var i = 0
    +        while (iter.hasNext) {
    +          // Our iterator iterates left-to-right across the active node 
periphery
    +          val nodeIndexInLevel = iter.next
    +          // Features for the current node start at fromOffset and end at 
toOffset
    +          val fromOffset = nodeOffsets(nodeIndexInLevel)
    +          val toOffset = nodeOffsets(nodeIndexInLevel + 1)
    +          // Get the impurity aggregator for the current node
    +          val fullImpurityAgg = fullImpurityAggs(nodeIndexInLevel)
    +          // Get the best split for each feature for the current node
    +
    +          // TODO(smurching): In PartitionInfo, keep track of which 
features are associated
    +          // with which nodes and subsample here
    +          val splitsAndStats =
    +            columns.map { col =>
    --- End diff --
    
    Maybe pass in the node here and actually split it, returning the new child nodes. There is
    no need to return the stats, since they are already stored in the split node. Probably.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to