Repository: spark
Updated Branches:
  refs/heads/master 3e770a64a -> e963070c1


[SPARK-9722] [ML] Pass random seed to spark.ml DecisionTree*

Author: Yu ISHIKAWA <yuu.ishik...@gmail.com>

Closes #9402 from yu-iskw/SPARK-9722.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e963070c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e963070c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e963070c

Branch: refs/heads/master
Commit: e963070c13f56fbc2dfaf9f5d4e69d34afd0957c
Parents: 3e770a6
Author: Yu ISHIKAWA <yuu.ishik...@gmail.com>
Authored: Sun Nov 1 23:52:50 2015 -0800
Committer: DB Tsai <d...@netflix.com>
Committed: Sun Nov 1 23:52:50 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/tree/impl/RandomForest.scala   | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e963070c/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 96d5652..4a3b12d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -74,7 +74,7 @@ private[ml] object RandomForest extends Logging {
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     timer.start("findSplitsBins")
-    val splits = findSplits(retaggedInput, metadata)
+    val splits = findSplits(retaggedInput, metadata, seed)
     timer.stop("findSplitsBins")
     logDebug("numBins: feature: number of bins")
     logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
@@ -815,6 +815,7 @@ private[ml] object RandomForest extends Logging {
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    * @param metadata Learning and dataset metadata
+   * @param seed random seed
    * @return A tuple of (splits, bins).
    *         Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
    *          of size (numFeatures, numSplits).
@@ -823,7 +824,8 @@ private[ml] object RandomForest extends Logging {
    */
   protected[tree] def findSplits(
       input: RDD[LabeledPoint],
-      metadata: DecisionTreeMetadata): Array[Array[Split]] = {
+      metadata: DecisionTreeMetadata,
+      seed : Long): Array[Array[Split]] = {
 
     logDebug("isMulticlass = " + metadata.isMulticlass)
 
@@ -840,7 +842,7 @@ private[ml] object RandomForest extends Logging {
         1.0
       }
       logDebug("fraction of data used for calculating quantiles = " + fraction)
-      input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
+      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
     } else {
       new Array[LabeledPoint](0)
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to