Repository: spark Updated Branches: refs/heads/master 646f38346 -> 72d9fba26
[SPARK-17281][ML][MLLIB] Add treeAggregateDepth parameter for AFTSurvivalRegression ## What changes were proposed in this pull request? Add treeAggregateDepth parameter for AFTSurvivalRegression to keep consistent with LiR/LoR. ## How was this patch tested? Existing tests. Author: WeichenXu <weichenxu...@outlook.com> Closes #14851 from WeichenXu123/add_treeAggregate_param_for_survival_regression. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/72d9fba2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/72d9fba2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/72d9fba2 Branch: refs/heads/master Commit: 72d9fba26c19aae73116fd0d00b566967934c6fc Parents: 646f383 Author: WeichenXu <weichenxu...@outlook.com> Authored: Thu Sep 22 04:35:54 2016 -0700 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Thu Sep 22 04:35:54 2016 -0700 ---------------------------------------------------------------------- .../ml/regression/AFTSurvivalRegression.scala | 24 ++++++++++++++++---- python/pyspark/ml/regression.py | 11 +++++---- 2 files changed, 25 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/72d9fba2/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 3179f48..9d5ba99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -46,7 +46,7 @@ import org.apache.spark.storage.StorageLevel */ private[regression] trait AFTSurvivalRegressionParams extends Params with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter - with HasTol with HasFitIntercept with Logging { + with HasTol with HasFitIntercept with HasAggregationDepth with Logging { /** * Param for censor column name. @@ -184,6 +184,17 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S setDefault(tol -> 1E-6) /** + * Suggested depth for treeAggregate (>= 2). + * If the dimensions of features or the number of partitions are large, + * this param could be adjusted to a larger size. + * Default is 2. + * @group expertSetParam + */ + @Since("2.1.0") + def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) + setDefault(aggregationDepth -> 2) + + /** * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, * and put it in an RDD with strong types. */ @@ -207,7 +218,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => { c1.merge(c2) } - instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp) + instances.treeAggregate( + new MultivariateOnlineSummarizer + )(seqOp, combOp, $(aggregationDepth)) } val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) @@ -222,7 +235,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val bcFeaturesStd = instances.context.broadcast(featuresStd) - val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd) + val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd, $(aggregationDepth)) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) /* @@ -591,7 +604,8 @@ private class AFTAggregator( private class AFTCostFun( data: RDD[AFTPoint], fitIntercept: Boolean, - bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] { + bcFeaturesStd: Broadcast[Array[Double]], + aggregationDepth: Int) extends DiffFunction[BDV[Double]] { override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { @@ -604,7 +618,7 @@ private class AFTCostFun( }, combOp = (c1, c2) => (c1, c2) match { case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + }, depth = aggregationDepth) bcParameters.destroy(blocking = false) (aftAggregator.loss, aftAggregator.gradient) http://git-wip-us.apache.org/repos/asf/spark/blob/72d9fba2/python/pyspark/ml/regression.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 19afc72..55d3803 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1088,7 +1088,8 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable): + HasFitIntercept, HasMaxIter, HasTol, HasAggregationDepth, + JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -1153,12 +1154,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), - quantilesCol=None): + quantilesCol=None, aggregationDepth=2): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None) + quantilesCol=None, aggregationDepth=2) """ super(AFTSurvivalRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -1174,12 +1175,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), - quantilesCol=None): + quantilesCol=None, aggregationDepth=2): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None): + quantilesCol=None, aggregationDepth=2): """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org