Repository: spark
Updated Branches:
  refs/heads/master 2a4e00ca4 -> e8ea5bafe


[SPARK-9910] [ML] User guide for train validation split

Author: martinzapletal <zapletal-mar...@email.cz>

Closes #8377 from zapletal-martin/SPARK-9910.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e8ea5baf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e8ea5baf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e8ea5baf

Branch: refs/heads/master
Commit: e8ea5bafee9ca734edf62021145d0c2d5491cba8
Parents: 2a4e00c
Author: martinzapletal <zapletal-mar...@email.cz>
Authored: Fri Aug 28 21:03:48 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Fri Aug 28 21:03:48 2015 -0700

----------------------------------------------------------------------
 docs/ml-guide.md                                | 117 +++++++++++++++++++
 .../ml/JavaTrainValidationSplitExample.java     |  90 ++++++++++++++
 .../ml/TrainValidationSplitExample.scala        |  80 +++++++++++++
 3 files changed, 287 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e8ea5baf/docs/ml-guide.md
----------------------------------------------------------------------
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index ce53400..a92a285 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -872,3 +872,120 @@ jsc.stop();
 </div>
 
 </div>
+
+## Example: Model Selection via Train Validation Split
+In addition to `CrossValidator`, Spark also offers `TrainValidationSplit` for 
hyper-parameter tuning.
+`TrainValidationSplit` only evaluates each combination of parameters once as 
opposed to k times in
 the case of `CrossValidator`. It is therefore less expensive,
+ but will not produce as reliable results when the training dataset is not 
sufficiently large.
+
+`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in 
the `estimatorParamMaps` parameter,
+and an `Evaluator`.
+It begins by splitting the dataset into two parts using the `trainRatio` parameter
+which are used as separate training and test datasets. For example with 
`$trainRatio=0.75$` (default),
+`TrainValidationSplit` will generate a training and test dataset pair where 
75% of the data is used for training and 25% for validation.
+Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the 
set of `ParamMap`s.
+For each combination of parameters, it trains the given `Estimator` and 
evaluates it using the given `Evaluator`.
+The `ParamMap` which produces the best evaluation metric is selected as the 
best option.
+`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` 
and the entire dataset.
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% highlight scala %}
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
+import org.apache.spark.mllib.util.MLUtils
+
+// Prepare training and test data.
+val data = MLUtils.loadLibSVMFile(sc, 
"data/mllib/sample_libsvm_data.txt").toDF()
+val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)
+
+val lr = new LinearRegression()
+
+// We use a ParamGridBuilder to construct a grid of parameters to search over.
+// TrainValidationSplit will try all combinations of values and determine best 
model using
+// the evaluator.
+val paramGrid = new ParamGridBuilder()
+  .addGrid(lr.regParam, Array(0.1, 0.01))
+  .addGrid(lr.fitIntercept, Array(true, false))
+  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
+  .build()
+
+// In this case the estimator is simply the linear regression.
+// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, 
and an Evaluator.
+val trainValidationSplit = new TrainValidationSplit()
+  .setEstimator(lr)
+  .setEvaluator(new RegressionEvaluator)
+  .setEstimatorParamMaps(paramGrid)
+
+// 80% of the data will be used for training and the remaining 20% for 
validation.
+trainValidationSplit.setTrainRatio(0.8)
+
+// Run train validation split, and choose the best set of parameters.
+val model = trainValidationSplit.fit(training)
+
+// Make predictions on test data. model is the model with combination of 
parameters
+// that performed best.
+model.transform(test)
+  .select("features", "label", "prediction")
+  .show()
+
+{% endhighlight %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% highlight java %}
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.regression.LinearRegression;
+import org.apache.spark.ml.tuning.*;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+
+DataFrame data = jsql.createDataFrame(
+  MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
+  LabeledPoint.class);
+
+// Prepare training and test data.
+DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
+DataFrame training = splits[0];
+DataFrame test = splits[1];
+
+LinearRegression lr = new LinearRegression();
+
+// We use a ParamGridBuilder to construct a grid of parameters to search over.
+// TrainValidationSplit will try all combinations of values and determine best 
model using
+// the evaluator.
+ParamMap[] paramGrid = new ParamGridBuilder()
+  .addGrid(lr.regParam(), new double[] {0.1, 0.01})
+  .addGrid(lr.fitIntercept())
+  .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
+  .build();
+
+// In this case the estimator is simply the linear regression.
+// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, 
and an Evaluator.
+TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
+  .setEstimator(lr)
+  .setEvaluator(new RegressionEvaluator())
+  .setEstimatorParamMaps(paramGrid);
+
+// 80% of the data will be used for training and the remaining 20% for 
validation.
+trainValidationSplit.setTrainRatio(0.8);
+
+// Run train validation split, and choose the best set of parameters.
+TrainValidationSplitModel model = trainValidationSplit.fit(training);
+
+// Make predictions on test data. model is the model with combination of 
parameters
+// that performed best.
+model.transform(test)
+  .select("features", "label", "prediction")
+  .show();
+
+{% endhighlight %}
+</div>
+
+</div>

http://git-wip-us.apache.org/repos/asf/spark/blob/e8ea5baf/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
----------------------------------------------------------------------
diff --git 
a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
 
b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
new file mode 100644
index 0000000..23f834a
--- /dev/null
+++ 
b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.RegressionEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.regression.LinearRegression;
+import org.apache.spark.ml.tuning.*;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+/**
+ * A simple example demonstrating model selection using TrainValidationSplit.
+ *
+ * The example is based on {@link 
org.apache.spark.examples.ml.JavaSimpleParamsExample}
+ * using linear regression.
+ *
+ * Run with
+ * {{{
+ * bin/run-example ml.JavaTrainValidationSplitExample
+ * }}}
+ */
+public class JavaTrainValidationSplitExample {
+
+  public static void main(String[] args) {
+    SparkConf conf = new 
SparkConf().setAppName("JavaTrainValidationSplitExample");
+    JavaSparkContext jsc = new JavaSparkContext(conf);
+    SQLContext jsql = new SQLContext(jsc);
+
+    DataFrame data = jsql.createDataFrame(
+      MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
+      LabeledPoint.class);
+
+    // Prepare training and test data.
+    DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
+    DataFrame training = splits[0];
+    DataFrame test = splits[1];
+
+    LinearRegression lr = new LinearRegression();
+
+    // We use a ParamGridBuilder to construct a grid of parameters to search 
over.
+    // TrainValidationSplit will try all combinations of values and determine 
best model using
+    // the evaluator.
+    ParamMap[] paramGrid = new ParamGridBuilder()
+      .addGrid(lr.regParam(), new double[] {0.1, 0.01})
+      .addGrid(lr.fitIntercept())
+      .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
+      .build();
+
+    // In this case the estimator is simply the linear regression.
+    // A TrainValidationSplit requires an Estimator, a set of Estimator 
ParamMaps, and an Evaluator.
+    TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
+      .setEstimator(lr)
+      .setEvaluator(new RegressionEvaluator())
+      .setEstimatorParamMaps(paramGrid);
+
+    // 80% of the data will be used for training and the remaining 20% for 
validation.
+    trainValidationSplit.setTrainRatio(0.8);
+
+    // Run train validation split, and choose the best set of parameters.
+    TrainValidationSplitModel model = trainValidationSplit.fit(training);
+
+    // Make predictions on test data. model is the model with combination of 
parameters
+    // that performed best.
+    model.transform(test)
+      .select("features", "label", "prediction")
+      .show();
+
+    jsc.stop();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e8ea5baf/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala
----------------------------------------------------------------------
diff --git 
a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala
 
b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala
new file mode 100644
index 0000000..1abdf21
--- /dev/null
+++ 
b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
+
+/**
+ * A simple example demonstrating model selection using TrainValidationSplit.
+ *
+ * The example is based on [[SimpleParamsExample]] using linear regression.
+ * Run with
+ * {{{
+ * bin/run-example ml.TrainValidationSplitExample
+ * }}}
+ */
+object TrainValidationSplitExample {
+
+  def main(args: Array[String]): Unit = {
+    val conf = new SparkConf().setAppName("TrainValidationSplitExample")
+    val sc = new SparkContext(conf)
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+
+    // Prepare training and test data.
+    val data = MLUtils.loadLibSVMFile(sc, 
"data/mllib/sample_libsvm_data.txt").toDF()
+    val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)
+
+    val lr = new LinearRegression()
+
+    // We use a ParamGridBuilder to construct a grid of parameters to search 
over.
+    // TrainValidationSplit will try all combinations of values and determine 
best model using
+    // the evaluator.
+    val paramGrid = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.1, 0.01))
+      .addGrid(lr.fitIntercept, Array(true, false))
+      .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
+      .build()
+
+    // In this case the estimator is simply the linear regression.
+    // A TrainValidationSplit requires an Estimator, a set of Estimator 
ParamMaps, and an Evaluator.
+    val trainValidationSplit = new TrainValidationSplit()
+      .setEstimator(lr)
+      .setEvaluator(new RegressionEvaluator)
+      .setEstimatorParamMaps(paramGrid)
+
+    // 80% of the data will be used for training and the remaining 20% for 
validation.
+    trainValidationSplit.setTrainRatio(0.8)
+
+    // Run train validation split, and choose the best set of parameters.
+    val model = trainValidationSplit.fit(training)
+
+    // Make predictions on test data. model is the model with combination of 
parameters
+    // that performed best.
+    model.transform(test)
+      .select("features", "label", "prediction")
+      .show()
+
+    sc.stop()
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to