Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/3374#discussion_r20629859
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala ---
    @@ -40,151 +39,98 @@ import org.apache.spark.storage.StorageLevel
      * Notes:
      *  - This currently can be run with several loss functions.  However, only SquaredError is
      *    fully supported.  Specifically, the loss function should be used to compute the gradient
    - *    (to re-label training instances on each iteration) and to weight weak hypotheses.
    + *    (to re-label training instances on each iteration) and to weight tree ensembles.
      *    Currently, gradients are computed correctly for the available loss functions,
    - *    but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
    - *    Running with those losses will likely behave reasonably, but lacks the same guarantees.
    + *    but tree predictions are not computed correctly for LogLoss or AbsoluteError since they
    + *    use the mean of the samples at each leaf node.  Running with those losses will likely behave
    + *    reasonably, but lacks the same guarantees.
      *
    - * @param boostingStrategy Parameters for the gradient boosting algorithm
    + * @param boostingStrategy Parameters for the gradient boosting algorithm.
      */
     @Experimental
    -class GradientBoosting (
    -    private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
    -
    -  boostingStrategy.weakLearnerParams.algo = Regression
    -  boostingStrategy.weakLearnerParams.impurity = impurity.Variance
    -
    -  // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
    -  boostingStrategy.weakLearnerParams.numClassesForClassification =
    -    boostingStrategy.numClassesForClassification
    -
    -  boostingStrategy.assertValid()
    +class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
    +  extends Serializable with Logging {
     
       /**
        * Method to train a gradient boosting model
        * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    -   * @return WeightedEnsembleModel that can be used for prediction
    +   * @return a gradient boosted trees model that can be used for prediction
        */
    -  def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
    -    val algo = boostingStrategy.algo
    +  def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
    +    val algo = boostingStrategy.treeStrategy.algo
         algo match {
    -      case Regression => GradientBoosting.boost(input, boostingStrategy)
    +      case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
           case Classification =>
             // Map labels to -1, +1 so binary classification can be treated as regression.
             val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
    -        GradientBoosting.boost(remappedInput, boostingStrategy)
    +        GradientBoostedTrees.boost(remappedInput, boostingStrategy)
           case _ =>
             throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
         }
       }
     
    +  /**
    +   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
    +   */
    +  def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
    +    run(input.rdd)
    +  }
     }
     
     
    -object GradientBoosting extends Logging {
    +object GradientBoostedTrees extends Logging {
     
       /**
        * Method to train a gradient boosting model.
        *
    -   * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
    -   *       is recommended to clearly specify regression.
    -   *       Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
    -   *       is recommended to clearly specify regression.
    -   *
        * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
        *              For classification, labels should take values {0, 1, ..., numClasses-1}.
        *              For regression, labels are real numbers.
        * @param boostingStrategy Configuration options for the boosting algorithm.
    -   * @return WeightedEnsembleModel that can be used for prediction
    +   * @return a gradient boosted trees model that can be used for prediction
        */
       def train(
           input: RDD[LabeledPoint],
    -      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
    -    new GradientBoosting(boostingStrategy).train(input)
    +      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
    +    new GradientBoostedTrees(boostingStrategy).run(input)
       }
     
       /**
    -   * Method to train a gradient boosting classification model.
    -   *
    -   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    -   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
    -   *              For regression, labels are real numbers.
    -   * @param boostingStrategy Configuration options for the boosting algorithm.
    -   * @return WeightedEnsembleModel that can be used for prediction
    -   */
    -  def trainClassifier(
    -      input: RDD[LabeledPoint],
    -      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
    -    val algo = boostingStrategy.algo
    -    require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
    -    new GradientBoosting(boostingStrategy).train(input)
    -  }
    -
    -  /**
    -   * Method to train a gradient boosting regression model.
    -   *
    -   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    -   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
    -   *              For regression, labels are real numbers.
    -   * @param boostingStrategy Configuration options for the boosting algorithm.
    -   * @return WeightedEnsembleModel that can be used for prediction
    -   */
    -  def trainRegressor(
    -      input: RDD[LabeledPoint],
    -      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
    -    val algo = boostingStrategy.algo
    -    require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
    -    new GradientBoosting(boostingStrategy).train(input)
    -  }
    -
    -  /**
    -   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]]
    +   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
        */
       def train(
    -    input: JavaRDD[LabeledPoint],
    -    boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
    -    train(input.rdd, boostingStrategy)
    -  }
    -
    -  /**
    -   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
    -   */
    -  def trainClassifier(
    -      input: JavaRDD[LabeledPoint],
    -      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
    -    trainClassifier(input.rdd, boostingStrategy)
    -  }
    -
    -  /**
    -   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
    -   */
    -  def trainRegressor(
           input: JavaRDD[LabeledPoint],
    -      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
    -    trainRegressor(input.rdd, boostingStrategy)
    +      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
    +    train(input.rdd, boostingStrategy)
       }
     
       /**
        * Internal method for performing regression using trees as base learners.
        * @param input training dataset
        * @param boostingStrategy boosting parameters
    -   * @return
    +   * @return a gradient boosted trees model that can be used for prediction
        */
       private def boost(
           input: RDD[LabeledPoint],
    -      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
    +      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
     
         val timer = new TimeTracker()
         timer.start("total")
         timer.start("init")
     
    +    boostingStrategy.assertValid()
    +
         // Initialize gradient boosting parameters
         val numIterations = boostingStrategy.numIterations
         val baseLearners = new Array[DecisionTreeModel](numIterations)
         val baseLearnerWeights = new Array[Double](numIterations)
         val loss = boostingStrategy.loss
         val learningRate = boostingStrategy.learningRate
    -    val strategy = boostingStrategy.weakLearnerParams
    +    // Prepare strategy for tree ensembles. Tree ensembles use regression with variance impurity.
    --- End diff --
    
    done


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to