Repository: spark
Updated Branches:
  refs/heads/branch-1.6 83b957a35 -> de03021ac


[SPARK-11514][ML] Pass random seed to spark.ml DecisionTree*

cc jkbradley

Author: Yu ISHIKAWA <yuu.ishik...@gmail.com>

Closes #9486 from yu-iskw/SPARK-11514.

(cherry picked from commit 8fa8c8375d7015a0332aa9ee613d7c6b6d62bae7)
Signed-off-by: Joseph K. Bradley <jos...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/de03021a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/de03021a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/de03021a

Branch: refs/heads/branch-1.6
Commit: de03021ac9058e56481c227028cdd4635c803713
Parents: 83b957a
Author: Yu ISHIKAWA <yuu.ishik...@gmail.com>
Authored: Thu Nov 5 17:59:01 2015 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Thu Nov 5 17:59:11 2015 -0800

----------------------------------------------------------------------
 .../spark/ml/classification/DecisionTreeClassifier.scala |  4 +++-
 .../spark/ml/regression/DecisionTreeRegressor.scala      |  4 +++-
 .../main/scala/org/apache/spark/ml/tree/treeParams.scala | 11 ++++++-----
 .../ml/classification/DecisionTreeClassifierSuite.scala  |  1 +
 .../spark/ml/regression/DecisionTreeRegressorSuite.scala |  1 +
 5 files changed, 14 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/de03021a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index b0157f7..c478aea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -62,6 +62,8 @@ final class DecisionTreeClassifier(override val uid: String)
 
   override def setImpurity(value: String): this.type = super.setImpurity(value)
 
+  override def setSeed(value: Long): this.type = super.setSeed(value)
+
   override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -75,7 +77,7 @@ final class DecisionTreeClassifier(override val uid: String)
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
     val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 0L, parentUID = Some(uid))
+      seed = $(seed), parentUID = Some(uid))
     trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/de03021a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 04420fc..477030d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -71,13 +71,15 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
   @Since("1.4.0")
   override def setImpurity(value: String): this.type = super.setImpurity(value)
 
+  override def setSeed(value: Long): this.type = super.setSeed(value)
+
   override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy = getOldStrategy(categoricalFeatures)
     val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 0L, parentUID = Some(uid))
+      seed = $(seed), parentUID = Some(uid))
     trees.head.asInstanceOf[DecisionTreeRegressionModel]
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/de03021a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 281ba6e..1da97db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -29,7 +29,8 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
  *
  * Note: Marked as private and DeveloperApi since this may be made public in the future.
  */
-private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval {
+private[ml] trait DecisionTreeParams extends PredictorParams
+  with HasCheckpointInterval with HasSeed {
 
   /**
    * Maximum depth of the tree (>= 0).
@@ -123,6 +124,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI
   /** @group getParam */
   final def getMinInfoGain: Double = $(minInfoGain)
 
+  /** @group setParam */
+  def setSeed(value: Long): this.type = set(seed, value)
+
   /** @group expertSetParam */
   def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
 
@@ -257,7 +261,7 @@ private[ml] object TreeRegressorParams {
  *
  * Note: Marked as private and DeveloperApi since this may be made public in the future.
  */
-private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
+private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
 
   /**
    * Fraction of the training data used for learning each decision tree, in range (0, 1].
@@ -276,9 +280,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
   /** @group getParam */
   final def getSubsamplingRate: Double = $(subsamplingRate)
 
-  /** @group setParam */
-  def setSeed(value: Long): this.type = set(seed, value)
-
   /**
    * Create a Strategy instance to use with the old API.
    * NOTE: The caller should set impurity and seed.

http://git-wip-us.apache.org/repos/asf/spark/blob/de03021a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 815f6fd..92b8f84 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -72,6 +72,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
       .setImpurity("gini")
       .setMaxDepth(2)
       .setMaxBins(100)
+      .setSeed(1)
     val categoricalFeatures = Map(0 -> 3, 1-> 3)
     val numClasses = 2
     compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)

http://git-wip-us.apache.org/repos/asf/spark/blob/de03021a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 868fb8e..e0d5afa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -49,6 +49,7 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
       .setImpurity("variance")
       .setMaxDepth(2)
       .setMaxBins(100)
+      .setSeed(1)
     val categoricalFeatures = Map(0 -> 3, 1-> 3)
     compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to