GitHub user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/14872#discussion_r76853423
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeUtils.scala ---
    @@ -0,0 +1,259 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import org.roaringbitmap.RoaringBitmap
    +
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.ml.classification.DecisionTreeClassificationModel
    +import org.apache.spark.ml.regression.DecisionTreeRegressionModel
    +import org.apache.spark.ml.tree._
    +import org.apache.spark.mllib.linalg.Vector
    +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
    +import org.apache.spark.mllib.tree.impurity._
    +import org.apache.spark.mllib.tree.model.{ImpurityStats, Predict}
    +import org.apache.spark.util.collection.BitSet
    +
    +/**
    + * DecisionTree which partitions data by feature.
    + *
    + * Algorithm:
    + *  - Repartition data, grouping by feature.
    + *  - Prep data (sort continuous features).
    + *  - On each partition, initialize instance--node map with each instance at root node.
    + *  - Iterate, training 1 new level of the tree at a time:
    + *     - On each partition, for each feature on the partition, select the best split for each node.
    + *     - Aggregate best split for each node.
    + *     - Aggregate bit vector (1 bit/instance) indicating whether each instance splits
    + *       left or right.
    + *     - Broadcast bit vector.  On each partition, update instance--node map.
    + *
    + * TODO: Update to use a sparse column store.
    + */
    +private[ml] object LocalDecisionTreeUtils extends Logging {
    +
    +  /**
    +   * Convert a dataset of [[Vector]] from row storage to column storage.
    +   * This can take any [[Vector]] type but stores data as [[DenseVector]].
    +   *
    +   * This maintains sparsity in the data.
    +   *
    +   * This maintains matrix structure.  I.e., each partition of the output RDD holds adjacent
    +   * columns.  The number of partitions will be min(input RDD's number of partitions, numColumns).
    +   *
    +   * @param rowStore  An array of input data rows, each represented as an
    +   *                  int array of binned feature values
    +   * @return Transpose of rowStore, with one int array per column (feature)
    +   *
    +   * TODO: Add implementation for sparse data.
    +   *       For sparse data, distribute more evenly based on number of non-zeros.
    +   *       (First collect stats to decide how to partition.)
    +   */
    +  private[impl] def rowToColumnStoreDense(rowStore: Array[Array[Int]]): Array[Array[Int]] = {
    +    // Compute the number of rows in the data
    +    val numRows = {
    +      val longNumRows: Long = rowStore.size
    +      require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," +
    +        s" but can handle at most ${Int.MaxValue} rows")
    +      longNumRows.toInt
    +    }
    +
    +    // Return an empty array for a dataset with zero rows or columns, otherwise
    +    // return the transpose of the rowStore matrix
    +    if (numRows == 0 || rowStore(0).size == 0) {
    +      Array.empty
    +    } else {
    +      val numCols = rowStore(0).size
    +      0.until(numCols).map { colIdx =>
    +        rowStore.map(row => row(colIdx))
    +      }.toArray
    +    }
    +  }
    +
    +  private[impl] def finalizeTree(
    +      rootNode: Node,
    +      algo: OldAlgo.Algo,
    +      numClasses: Int,
    +      numFeatures: Int,
    +      parentUID: Option[String]): DecisionTreeModel = {
    +    parentUID match {
    +      case Some(uid) =>
    +        if (algo == OldAlgo.Classification) {
    +          new DecisionTreeClassificationModel(uid, rootNode, numFeatures = numFeatures,
    +            numClasses = numClasses)
    +        } else {
    +          new DecisionTreeRegressionModel(uid, rootNode, numFeatures = numFeatures)
    +        }
    +      case None =>
    +        if (algo == OldAlgo.Classification) {
    +          new DecisionTreeClassificationModel(rootNode, numFeatures = numFeatures,
    +            numClasses = numClasses)
    +        } else {
    +          new DecisionTreeRegressionModel(rootNode, numFeatures = numFeatures)
    +        }
    +    }
    +  }
    +
    +  private[impl] def getPredict(impurityCalculator: ImpurityCalculator): Predict = {
    +    val pred = impurityCalculator.predict
    +    new Predict(predict = pred, prob = impurityCalculator.prob(pred))
    +  }
    +
    +  /**
    +   * On driver: Grow tree based on chosen splits, and compute new set of active nodes.
    +   *
    +   * @param oldPeriphery  Old periphery of active nodes.
    +   * @param bestSplitsAndGains  Best (split, gain) pairs, which can be zipped with the old
    +   *                            periphery.  These stats will be used to replace the stats in
    +   *                            any nodes which are split.
    +   * @param minInfoGain  Threshold for min info gain required to split a node.
    +   * @return  New active node periphery.
    +   *          If a node is split, then this method will update its fields.
    +   */
    +  private[impl] def computeActiveNodePeriphery(
    --- End diff --
    
    Can this be merged into the other logic?  Here, driver/worker logic is not separate, so they should be merged.
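
The per-level update described in the class doc comment of the diff (aggregate and broadcast a 1-bit-per-instance vector, then update the instance--node map on each partition) can be sketched as below. This is a minimal, hypothetical sketch, not code from the PR: it assumes heap-style node indices (root = 1, children of node i are 2i and 2i + 1), uses Scala's standard `mutable.BitSet` in place of the `RoaringBitmap` / Spark `BitSet` that the file imports, and the names `InstanceNodeMapSketch` and `updateInstanceToNode` are illustrative only.

```scala
import scala.collection.mutable

// Hypothetical sketch of one level's instance--node map update.
// Assumes heap-style indices: root = 1, children of node i are 2*i and 2*i + 1.
// Not the PR's implementation and not a Spark API.
object InstanceNodeMapSketch {

  /**
   * @param instanceToNode current node index for every instance
   * @param splitNodes     node indices that were split at this level
   * @param goesRight      bit i is set iff instance i goes to the right child
   * @return new node index for every instance
   */
  def updateInstanceToNode(
      instanceToNode: Array[Int],
      splitNodes: Set[Int],
      goesRight: mutable.BitSet): Array[Int] = {
    Array.tabulate(instanceToNode.length) { i =>
      val node = instanceToNode(i)
      if (!splitNodes.contains(node)) {
        node // node was not split (e.g. it became a leaf): instance stays put
      } else if (goesRight.contains(i)) {
        2 * node + 1 // right child
      } else {
        2 * node // left child
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Four instances, all at the root (node 1) before the first level.
    val atRoot = Array(1, 1, 1, 1)
    // Suppose the chosen root split sends instances 1 and 3 to the right.
    val bits = mutable.BitSet(1, 3)
    val level1 = updateInstanceToNode(atRoot, Set(1), bits)
    println(level1.mkString(",")) // prints "2,3,2,3"
  }
}
```

Because only instances at split nodes consult the bit vector, bits for instances sitting at unsplit nodes are simply ignored, so one broadcast vector per level suffices for every partition.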

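The `computeActiveNodePeriphery` doc comment describes zipping the old periphery with the best (split, gain) pairs, splitting the nodes whose gain meets `minInfoGain` (updating their fields), and returning the new periphery. The method body is not shown in this diff, so the following is only a hypothetical sketch of that contract: `SketchNode`, `SplitCandidate`, the heap-style child indices, and the `gain < minInfoGain` leaf test are all assumptions, not the PR's or Spark's actual types and behavior.

```scala
// Hypothetical sketch of the computeActiveNodePeriphery contract.
// SketchNode and SplitCandidate are illustrative stand-ins, not Spark types.
final case class SplitCandidate(featureIndex: Int, threshold: Double)

final class SketchNode(val id: Int) {
  var isLeaf: Boolean = false
  var split: Option[SplitCandidate] = None
  var gain: Double = 0.0
  var left: Option[SketchNode] = None
  var right: Option[SketchNode] = None
}

object PeripherySketch {

  /**
   * Grow the tree by one level on the driver.
   *
   * @param oldPeriphery       active nodes at the current level
   * @param bestSplitsAndGains best (split, gain) per node, aligned with oldPeriphery
   * @param minInfoGain        minimum gain required to split a node (assumed comparison)
   * @return the new active periphery: children of the nodes that were split
   */
  def computeActiveNodePeriphery(
      oldPeriphery: Array[SketchNode],
      bestSplitsAndGains: Array[(SplitCandidate, Double)],
      minInfoGain: Double): Array[SketchNode] = {
    require(oldPeriphery.length == bestSplitsAndGains.length,
      "bestSplitsAndGains must zip one-to-one with oldPeriphery")
    oldPeriphery.toSeq.zip(bestSplitsAndGains.toSeq).flatMap { case (node, (split, gain)) =>
      if (gain < minInfoGain) {
        // Not worth splitting: mark as a leaf and drop it from the periphery.
        node.isLeaf = true
        Seq.empty[SketchNode]
      } else {
        // Split: update the node's fields in place and activate both children.
        node.split = Some(split)
        node.gain = gain
        val l = new SketchNode(2 * node.id)
        val r = new SketchNode(2 * node.id + 1)
        node.left = Some(l)
        node.right = Some(r)
        Seq(l, r)
      }
    }.toArray
  }
}
```

Zipping by position mirrors the doc comment's statement that `bestSplitsAndGains` lines up with the old periphery; the `require` makes a length mismatch fail fast instead of being silently truncated by `zip`.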
