Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19433#discussion_r150160368
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala ---
    @@ -0,0 +1,215 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.tree.impl
    +
    +import org.apache.spark.ml.tree.{CategoricalSplit, Split}
    +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
    +import org.apache.spark.mllib.tree.model.ImpurityStats
    +
    +/** Utility methods for choosing splits during local & distributed tree training. */
    +private[impl] object SplitUtils {
    +
    +  /** Sorts ordered feature categories by label centroid, returning an ordered list of categories. */
    +  private def sortByCentroid(
    +      binAggregates: DTStatsAggregator,
    +      featureIndex: Int,
    +      featureIndexIdx: Int): List[Int] = {
    +    /* Each bin is one category (feature value).
    +     * The bins are ordered based on centroidForCategories, and this ordering determines which
    +     * splits are considered.  (With K categories, we consider K - 1 possible splits.)
    +     *
    +     * centroidForCategories is a list: (category, centroid)
    +     */
    +    val numCategories = binAggregates.metadata.numBins(featureIndex)
    +    val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
    +
    +    val centroidForCategories = Range(0, numCategories).map { featureValue =>
    +      val categoryStats =
    +        binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
    +      val centroid = ImpurityUtils.getCentroid(binAggregates.metadata, categoryStats)
    +      (featureValue, centroid)
    +    }
    +    // TODO(smurching): How to handle logging statements like these?
    +    // logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
    +    // bins sorted by centroids
    +    val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2).map(_._1)
    +    // logDebug("Sorted centroids for categorical variable = " +
    +    //   categoriesSortedByCentroid.mkString(","))
    +    categoriesSortedByCentroid
    +  }
    +
    +  /**
    +   * Find the best split for an unordered categorical feature at a single node.
    +   *
    +   * Algorithm:
    +   *  - Considers all possible subsets (exponentially many)
    +   *
    +   * @param featureIndex  Global index of feature being split.
    +   * @param featureIndexIdx Index of feature being split within subset of features for current node.
    +   * @param featureSplits Array of splits for the current feature
    +   * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node
    +   * @return  (best split, statistics for split)  If no valid split was found, the returned
    +   *          ImpurityStats instance will be invalid (have member valid = false).
    +   */
    +  private[impl] def chooseUnorderedCategoricalSplit(
    +      binAggregates: DTStatsAggregator,
    +      featureIndex: Int,
    +      featureIndexIdx: Int,
    +      featureSplits: Array[Split],
    +      parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = {
    +    // Unordered categorical feature
    +    val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
    +    val numSplits = binAggregates.metadata.numSplits(featureIndex)
    +    var parentCalc = parentCalculator
    --- End diff --
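    A standalone illustration of the sort-by-centroid trick used in `sortByCentroid` above (plain-Scala sketch; the toy data, the mean-label centroid, and all names here are illustrative stand-ins, not the PR's actual `DTStatsAggregator` / `ImpurityUtils` code):
    
    ```scala
    object CentroidSortSketch {
      def main(args: Array[String]): Unit = {
        // Toy binary-classification data: category -> observed labels (illustrative).
        val labelsByCategory = Seq(
          0 -> Seq(1.0, 1.0, 0.0),
          1 -> Seq(0.0, 0.0, 0.0),
          2 -> Seq(1.0, 1.0, 1.0))
    
        // Centroid of a category = mean label (stand-in for ImpurityUtils.getCentroid).
        val centroidForCategories = labelsByCategory.map { case (cat, labels) =>
          (cat, labels.sum / labels.size)
        }
    
        // Sort categories by centroid; with K categories, only the K - 1 "prefix"
        // splits of this ordering need to be examined.
        val categoriesSortedByCentroid = centroidForCategories.sortBy(_._2).map(_._1)
        // categories ordered: 1, 0, 2
        val candidateLeftSets =
          (1 until categoriesSortedByCentroid.size).map(categoriesSortedByCentroid.take(_).toSet)
        // 2 candidate splits (Set(1) and Set(1, 0)) instead of 2^(3-1) - 1 = 3 subsets
        println(candidateLeftSets)
      }
    }
    ```
    
    For binary classification (and for regression with variance impurity) this ordering is known to contain the optimal split among its K - 1 contiguous cuts, which is what justifies treating such categorical features as "ordered" here.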
    
    It'd be nice to calculate the parentCalc right away here, if needed. That seems possible just by taking the first candidate split. Then we could simplify calculateImpurityStats by not passing in parentCalc as an option.
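    Roughly something like this (an illustrative sketch only, not the actual diff; it reuses `binAggregates` / `nodeFeatureOffset` from above and assumes `ImpurityCalculator`'s `copy`/`add` semantics, i.e. that left-child stats plus right-child stats of any one split reconstitute the parent's stats):
    
    ```scala
    // Sketch: materialize parentCalc eagerly from the first candidate split, so
    // calculateImpurityStats can take a plain ImpurityCalculator, not an Option.
    val parentCalc: ImpurityCalculator = parentCalculator.getOrElse {
      // Left + right child stats of the first candidate split cover the whole node.
      val leftStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, 0)
      val rightStats = ??? // complement of leftStats under the first split
      leftStats.copy.add(rightStats)
    }
    ```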


---
