Repository: spark
Updated Branches:
  refs/heads/master aad644fbe -> 64743870f


[SPARK-10394] [ML] Make GBTParams use shared stepSize

```GBTParams``` currently implements ```stepSize``` (the learning rate) itself.
ML already has the shared param trait ```HasStepSize```, so ```GBTParams``` can extend it
rather than duplicating the implementation.

Author: Yanbo Liang <[email protected]>

Closes #8552 from yanboliang/spark-10394.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/64743870
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/64743870
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/64743870

Branch: refs/heads/master
Commit: 64743870f23bffb8d96dcc8a0181c1452782a151
Parents: aad644f
Author: Yanbo Liang <[email protected]>
Authored: Thu Sep 17 11:24:38 2015 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Thu Sep 17 11:24:38 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/tree/treeParams.scala   | 28 +++++++++-----------
 1 file changed, 13 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/64743870/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index d29f525..42e74ce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.tree
 import org.apache.spark.ml.classification.ClassifierParams
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, 
HasSeed, HasThresholds}
+import org.apache.spark.ml.param.shared._
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, 
BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => 
OldGini, Impurity => OldImpurity, Variance => OldVariance}
 import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
@@ -365,17 +365,7 @@ private[ml] object RandomForestParams {
  *
  * Note: Marked as private and DeveloperApi since this may be made public in 
the future.
  */
-private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
-
-  /**
-   * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the 
contribution of each
-   * estimator.
-   * (default = 0.1)
-   * @group param
-   */
-  final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step 
size (a.k.a." +
-    " learning rate) in interval (0, 1] for shrinking the contribution of each 
estimator",
-    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = 
true))
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with 
HasStepSize {
 
   /* TODO: Add this doc when we add this param.  SPARK-7132
    * Threshold for stopping early when runWithValidation is used.
@@ -393,11 +383,19 @@ private[ml] trait GBTParams extends TreeEnsembleParams 
with HasMaxIter {
   /** @group setParam */
   def setMaxIter(value: Int): this.type = set(maxIter, value)
 
-  /** @group setParam */
+  /**
+   * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the 
contribution of each
+   * estimator.
+   * (default = 0.1)
+   * @group setParam
+   */
   def setStepSize(value: Double): this.type = set(stepSize, value)
 
-  /** @group getParam */
-  final def getStepSize: Double = $(stepSize)
+  override def validateParams(): Unit = {
+    require(ParamValidators.inRange(0, 1, lowerInclusive = false, 
upperInclusive = true)(
+      getStepSize), "GBT parameter stepSize should be in interval (0, 1], " +
+      s"but it given invalid value $getStepSize.")
+  }
 
   /** (private[ml]) Create a BoostingStrategy instance to use with the old 
API. */
   private[ml] def getOldBoostingStrategy(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to