Repository: spark
Updated Branches:
  refs/heads/master 646f38346 -> 72d9fba26


[SPARK-17281][ML][MLLIB] Add treeAggregateDepth parameter for AFTSurvivalRegression

## What changes were proposed in this pull request?

Add a treeAggregateDepth parameter to AFTSurvivalRegression, keeping it consistent with LinearRegression (LiR) and LogisticRegression (LoR).
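
With this change the aggregation depth can be tuned on the estimator just as on LiR/LoR. A minimal usage sketch, assuming a hypothetical `training` DataFrame (not part of this patch) with "features", "label" and "censor" columns:

```scala
import org.apache.spark.ml.regression.AFTSurvivalRegression

// `training` is a hypothetical DataFrame with "features", "label" and "censor" columns.
val aft = new AFTSurvivalRegression()
  .setCensorCol("censor")
  .setMaxIter(100)
  .setAggregationDepth(3) // the new expert param; defaults to 2
// val model = aft.fit(training)
```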

## How was this patch tested?

Existing tests.

Author: WeichenXu <weichenxu...@outlook.com>

Closes #14851 from WeichenXu123/add_treeAggregate_param_for_survival_regression.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/72d9fba2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/72d9fba2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/72d9fba2

Branch: refs/heads/master
Commit: 72d9fba26c19aae73116fd0d00b566967934c6fc
Parents: 646f383
Author: WeichenXu <weichenxu...@outlook.com>
Authored: Thu Sep 22 04:35:54 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Thu Sep 22 04:35:54 2016 -0700

----------------------------------------------------------------------
 .../ml/regression/AFTSurvivalRegression.scala   | 24 ++++++++++++++++----
 python/pyspark/ml/regression.py                 | 11 +++++----
 2 files changed, 25 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/72d9fba2/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 3179f48..9d5ba99 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -46,7 +46,7 @@ import org.apache.spark.storage.StorageLevel
  */
 private[regression] trait AFTSurvivalRegressionParams extends Params
   with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
-  with HasTol with HasFitIntercept with Logging {
+  with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
 
   /**
    * Param for censor column name.
@@ -184,6 +184,17 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   setDefault(tol -> 1E-6)
 
   /**
+   * Suggested depth for treeAggregate (>= 2).
+   * If the dimensions of features or the number of partitions are large,
+   * this param could be adjusted to a larger size.
+   * Default is 2.
+   * @group expertSetParam
+   */
+  @Since("2.1.0")
+  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+  setDefault(aggregationDepth -> 2)
+
+  /**
    * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
    * and put it in an RDD with strong types.
    */
@@ -207,7 +218,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
       val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
         c1.merge(c2)
       }
-      instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp)
+      instances.treeAggregate(
+        new MultivariateOnlineSummarizer
+      )(seqOp, combOp, $(aggregationDepth))
     }
 
     val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
@@ -222,7 +235,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
 
     val bcFeaturesStd = instances.context.broadcast(featuresStd)
 
-    val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd)
+    val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd, $(aggregationDepth))
     val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
 
     /*
@@ -591,7 +604,8 @@ private class AFTAggregator(
 private class AFTCostFun(
     data: RDD[AFTPoint],
     fitIntercept: Boolean,
-    bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] {
+    bcFeaturesStd: Broadcast[Array[Double]],
+    aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
 
   override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
 
@@ -604,7 +618,7 @@ private class AFTCostFun(
       },
       combOp = (c1, c2) => (c1, c2) match {
         case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
-      })
+      }, depth = aggregationDepth)
 
     bcParameters.destroy(blocking = false)
     (aftAggregator.loss, aftAggregator.gradient)
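
For context, the `depth` threaded through `AFTCostFun` above is the standard third argument of `RDD.treeAggregate`. The snippet below is an illustrative sketch of that aggregation pattern on a toy RDD; it is not code from this patch, and the session and data are made up:

```scala
import org.apache.spark.sql.SparkSession

// Toy data: sum a million longs spread across many partitions.
val spark = SparkSession.builder().appName("TreeAggregateDepthSketch").master("local[*]").getOrCreate()
val rdd = spark.sparkContext.parallelize(1L to 1000000L, numSlices = 200)

// depth = 2 is the default; larger values insert extra intermediate reduction
// levels, which can relieve the driver when there are many partitions or
// large per-partition results.
val total = rdd.treeAggregate(0L)(
  seqOp = (acc, x) => acc + x,
  combOp = (a, b) => a + b,
  depth = 3)
```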

http://git-wip-us.apache.org/repos/asf/spark/blob/72d9fba2/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 19afc72..55d3803 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1088,7 +1088,8 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
 
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
-                            HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable):
+                            HasFitIntercept, HasMaxIter, HasTol, HasAggregationDepth,
+                            JavaMLWritable, JavaMLReadable):
     """
     .. note:: Experimental
 
@@ -1153,12 +1154,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                  quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
-                 quantilesCol=None):
+                 quantilesCol=None, aggregationDepth=2):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
-                 quantilesCol=None)
+                 quantilesCol=None, aggregationDepth=2)
         """
         super(AFTSurvivalRegression, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -1174,12 +1175,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                   quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
-                  quantilesCol=None):
+                  quantilesCol=None, aggregationDepth=2):
         """
          setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                    fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                    quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
-                  quantilesCol=None):
+                  quantilesCol=None, aggregationDepth=2):
         """
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)

