Repository: spark
Updated Branches:
  refs/heads/master 124cbfb68 -> da60b34d2


[SPARK-3724][ML] RandomForest: More options for feature subset size.

## What changes were proposed in this pull request?

This PR adds support for more options for the feature subset size in the 
RandomForest implementation. Previously, RandomForest only supported "auto", 
"all", "sqrt", "log2", and "onethird". This PR adds support for arbitrary 
numeric values to allow model search.

In this PR, `featureSubsetStrategy` can be passed:
a) a real number in the range `(0.0, 1.0]` that represents the fraction of 
the number of features in each subset, or
b) an integer number (`> 0`) that represents the number of features in each 
subset.

## How was this patch tested?

Two tests `JavaRandomForestClassifierSuite` and 
`JavaRandomForestRegressorSuite` have been updated to check the additional 
options for params in this PR.
An additional test has been added to 
`org.apache.spark.mllib.tree.RandomForestSuite` to cover the cases in this PR.

Author: Yong Tang <yong.tang.git...@outlook.com>

Closes #11989 from yongtang/SPARK-3724.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/da60b34d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/da60b34d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/da60b34d

Branch: refs/heads/master
Commit: da60b34d2f6eba19633e4f1b46504ce92cd6c179
Parents: 124cbfb
Author: Yong Tang <yong.tang.git...@outlook.com>
Authored: Tue Apr 12 16:53:26 2016 +0200
Committer: Nick Pentreath <nick.pentre...@gmail.com>
Committed: Tue Apr 12 16:53:26 2016 +0200

----------------------------------------------------------------------
 .../ml/tree/impl/DecisionTreeMetadata.scala     |  5 +++
 .../org/apache/spark/ml/tree/treeParams.scala   |  8 ++++-
 .../apache/spark/mllib/tree/RandomForest.scala  | 11 ++++--
 .../JavaRandomForestClassifierSuite.java        | 19 +++++++++++
 .../JavaRandomForestRegressorSuite.java         | 19 +++++++++++
 .../spark/ml/tree/impl/RandomForestSuite.scala  | 36 ++++++++++++++++++++
 6 files changed, 95 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/da60b34d/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index df8eb5d..c7cde15 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -183,11 +183,16 @@ private[spark] object DecisionTreeMetadata extends 
Logging {
         }
       case _ => featureSubsetStrategy
     }
+
+    val isIntRegex = "^([1-9]\\d*)$".r
+    val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
     val numFeaturesPerNode: Int = _featureSubsetStrategy match {
       case "all" => numFeatures
       case "sqrt" => math.sqrt(numFeatures).ceil.toInt
       case "log2" => math.max(1, (math.log(numFeatures) / 
math.log(2)).ceil.toInt)
       case "onethird" => (numFeatures / 3.0).ceil.toInt
+      case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures 
else number.toInt
+      case isFractionRegex(fraction) => (fraction.toDouble * 
numFeatures).ceil.toInt
     }
 
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,

http://git-wip-us.apache.org/repos/asf/spark/blob/da60b34d/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 78e6d3b..0767dc1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -329,6 +329,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
    *  - "onethird": use 1/3 of the features
    *  - "sqrt": use sqrt(number of features)
    *  - "log2": use log2(number of features)
+   *  - "n": when n is in the range (0, 1.0], use n * number of features. When 
n
+   *         is in the range (1, number of features), use n features.
    * (default = "auto")
    *
    * These various settings are based on the following references:
@@ -346,7 +348,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
     "The number of features to consider for splits at each tree node." +
       s" Supported options: 
${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
     (value: String) =>
-      
RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
+      
RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
+      || 
value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))
 
   setDefault(featureSubsetStrategy -> "auto")
 
@@ -393,6 +396,9 @@ private[spark] object RandomForestParams {
   // These options should be lowercase.
   final val supportedFeatureSubsetStrategies: Array[String] =
     Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+
+  // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of 
features)
+  final val supportedFeatureSubsetStrategiesRegex = 
"^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
 }
 
 private[ml] trait RandomForestClassifierParams

http://git-wip-us.apache.org/repos/asf/spark/blob/da60b34d/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 1841fa4..2675584 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -55,10 +55,15 @@ import org.apache.spark.util.Utils
  * @param numTrees If 1, then no bootstrapping is used.  If > 1, then 
bootstrapping is done.
  * @param featureSubsetStrategy Number of features to consider for splits at 
each node.
  *                              Supported values: "auto", "all", "sqrt", 
"log2", "onethird".
+ *                              Supported numerical values: "(0.0-1.0]", 
"[1-n]".
  *                              If "auto" is set, this parameter is set based 
on numTrees:
  *                                if numTrees == 1, set to "all";
  *                                if numTrees > 1 (forest) set to "sqrt" for 
classification and
  *                                  to "onethird" for regression.
+ *                              If a real value "n" in the range (0, 1.0] is 
set,
+ *                                use n * number of features.
+ *                              If an integer value "n" in the range (1, num 
features) is set,
+ *                                use n features.
  * @param seed Random seed for bootstrapping and choosing feature subsets.
  */
 private class RandomForest (
@@ -70,9 +75,11 @@ private class RandomForest (
 
   strategy.assertValid()
   require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given 
numTrees = $numTrees.")
-  
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
+  
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
+    || 
featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
     s"RandomForest given invalid featureSubsetStrategy: 
$featureSubsetStrategy." +
-    s" Supported values: 
${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
+    s" Supported values: 
${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," +
+    s" (0.0-1.0], [1-n].")
 
   /**
    * Method to train a decision tree model over an RDD

http://git-wip-us.apache.org/repos/asf/spark/blob/da60b34d/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 7506146..5aec52a 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.Map;
 
 import org.junit.After;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -80,6 +81,24 @@ public class JavaRandomForestClassifierSuite implements 
Serializable {
     for (String featureSubsetStrategy: 
RandomForestClassifier.supportedFeatureSubsetStrategies()) {
       rf.setFeatureSubsetStrategy(featureSubsetStrategy);
     }
+    String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
+    for (String strategy: realStrategies) {
+      rf.setFeatureSubsetStrategy(strategy);
+    }
+    String integerStrategies[] = {"1", "10", "100", "1000", "10000"};
+    for (String strategy: integerStrategies) {
+      rf.setFeatureSubsetStrategy(strategy);
+    }
+    String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", 
"0"};
+    for (String strategy: invalidStrategies) {
+      try {
+        rf.setFeatureSubsetStrategy(strategy);
+        Assert.fail("Expected exception to be thrown for invalid strategies");
+      } catch (Exception e) {
+        Assert.assertTrue(e instanceof IllegalArgumentException);
+      }
+    }
+
     RandomForestClassificationModel model = rf.fit(dataFrame);
 
     model.transform(dataFrame);

http://git-wip-us.apache.org/repos/asf/spark/blob/da60b34d/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index b6f793f..a873666 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.Map;
 
 import org.junit.After;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -80,6 +81,24 @@ public class JavaRandomForestRegressorSuite implements 
Serializable {
     for (String featureSubsetStrategy: 
RandomForestRegressor.supportedFeatureSubsetStrategies()) {
       rf.setFeatureSubsetStrategy(featureSubsetStrategy);
     }
+    String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
+    for (String strategy: realStrategies) {
+      rf.setFeatureSubsetStrategy(strategy);
+    }
+    String integerStrategies[] = {"1", "10", "100", "1000", "10000"};
+    for (String strategy: integerStrategies) {
+      rf.setFeatureSubsetStrategy(strategy);
+    }
+    String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", 
"0"};
+    for (String strategy: invalidStrategies) {
+      try {
+        rf.setFeatureSubsetStrategy(strategy);
+        Assert.fail("Expected exception to be thrown for invalid strategies");
+      } catch (Exception e) {
+        Assert.assertTrue(e instanceof IllegalArgumentException);
+      }
+    }
+
     RandomForestRegressionModel model = rf.fit(dataFrame);
 
     model.transform(dataFrame);

http://git-wip-us.apache.org/repos/asf/spark/blob/da60b34d/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index cd402b1..6db9ce1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -426,12 +426,48 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       (math.log(numFeatures) / math.log(2)).ceil.toInt)
     checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 
3.0).ceil.toInt)
 
+    val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0")
+    for (strategy <- realStrategies) {
+      val expected = (strategy.toDouble * numFeatures).ceil.toInt
+      checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
+    }
+
+    val integerStrategies = Array("1", "10", "100", "1000", "10000")
+    for (strategy <- integerStrategies) {
+      val expected = if (strategy.toInt < numFeatures) strategy.toInt else 
numFeatures
+      checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
+    }
+
+    val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", 
"0")
+    for (invalidStrategy <- invalidStrategies) {
+      intercept[MatchError]{
+        val metadata =
+          DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, 
invalidStrategy)
+      }
+    }
+
     checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
     checkFeatureSubsetStrategy(numTrees = 2, "auto", 
math.sqrt(numFeatures).ceil.toInt)
     checkFeatureSubsetStrategy(numTrees = 2, "sqrt", 
math.sqrt(numFeatures).ceil.toInt)
     checkFeatureSubsetStrategy(numTrees = 2, "log2",
       (math.log(numFeatures) / math.log(2)).ceil.toInt)
     checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 
3.0).ceil.toInt)
+
+    for (strategy <- realStrategies) {
+      val expected = (strategy.toDouble * numFeatures).ceil.toInt
+      checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
+    }
+
+    for (strategy <- integerStrategies) {
+      val expected = if (strategy.toInt < numFeatures) strategy.toInt else 
numFeatures
+      checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
+    }
+    for (invalidStrategy <- invalidStrategies) {
+      intercept[MatchError]{
+        val metadata =
+          DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, 
invalidStrategy)
+      }
+    }
   }
 
   test("Binary classification with continuous features: subsampling features") 
{


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to