Repository: spark
Updated Branches:
  refs/heads/master 8bcad28a5 -> 3e3c3d58d


[SPARK-13706][ML] Add Python Example for Train Validation Split

## What changes were proposed in this pull request?

This pull request adds a Python example for train validation split.

## How was this patch tested?

The example was style-checked with lint-python, tested with ./dev/run-tests,
and run in both notebook and shell environments. The rendered docs were
checked locally with jekyll serve.

This contribution is my original work and I license it to Spark under its open 
source license.

Author: JeremyNixon <jnix...@gmail.com>

Closes #11547 from JeremyNixon/tvs_example.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3e3c3d58
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3e3c3d58
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3e3c3d58

Branch: refs/heads/master
Commit: 3e3c3d58d8d42b42e930d42eb70b0e84d02967eb
Parents: 8bcad28
Author: JeremyNixon <jnix...@gmail.com>
Authored: Thu Mar 10 09:09:56 2016 +0200
Committer: Nick Pentreath <nick.pentre...@gmail.com>
Committed: Thu Mar 10 09:18:15 2016 +0200

----------------------------------------------------------------------
 docs/ml-guide.md                                |  4 ++
 .../main/python/ml/train_validation_split.py    | 68 ++++++++++++++++++++
 2 files changed, 72 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3e3c3d58/docs/ml-guide.md
----------------------------------------------------------------------
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index a5a825f6..9916787 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -316,4 +316,8 @@ The `ParamMap` which produces the best evaluation metric is selected as the best
 {% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %}
 </div>
 
+<div data-lang="python">
+{% include_example python/ml/train_validation_split.py %}
+</div>
+
 </div>

http://git-wip-us.apache.org/repos/asf/spark/blob/3e3c3d58/examples/src/main/python/ml/train_validation_split.py
----------------------------------------------------------------------
diff --git a/examples/src/main/python/ml/train_validation_split.py b/examples/src/main/python/ml/train_validation_split.py
new file mode 100644
index 0000000..161a200
--- /dev/null
+++ b/examples/src/main/python/ml/train_validation_split.py
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+# $example on$
+from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.regression import LinearRegression
+from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
+from pyspark.sql import SQLContext
+# $example off$
+
+"""
+This example demonstrates applying TrainValidationSplit to split data
+and perform model selection.
+Run with:
+
+  bin/spark-submit examples/src/main/python/ml/train_validation_split.py
+"""
+
+if __name__ == "__main__":
+    sc = SparkContext(appName="TrainValidationSplit")
+    sqlContext = SQLContext(sc)
+    # $example on$
+    # Prepare training and test data.
+    data = sqlContext.read.format("libsvm")\
+        .load("data/mllib/sample_linear_regression_data.txt")
+    train, test = data.randomSplit([0.7, 0.3])
+    lr = LinearRegression(maxIter=10, regParam=0.1)
+
+    # We use a ParamGridBuilder to construct a grid of parameters to search over.
+    # TrainValidationSplit will try all combinations of values and determine best model using
+    # the evaluator.
+    paramGrid = ParamGridBuilder()\
+        .addGrid(lr.regParam, [0.1, 0.01]) \
+        .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])\
+        .build()
+
+    # In this case the estimator is simply the linear regression.
+    # A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+    tvs = TrainValidationSplit(estimator=lr,
+                               estimatorParamMaps=paramGrid,
+                               evaluator=RegressionEvaluator(),
+                               # 80% of the data will be used for training, 20% for validation.
+                               trainRatio=0.8)
+
+    # Run TrainValidationSplit, and choose the best set of parameters.
+    model = tvs.fit(train)
+    # Make predictions on test data. model is the model with combination of parameters
+    # that performed best.
+    prediction = model.transform(test)
+    for row in prediction.take(5):
+        print(row)
+    # $example off$
+    sc.stop()
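
For readers without a Spark installation, the selection procedure that the
example's comments describe (expand a parameter grid into all combinations,
split the data once by a train ratio, score each combination on the held-out
part, keep the best) can be sketched in plain Python. This is an illustrative
sketch only, not Spark's implementation: the helper names build_param_grid,
fit_ridge_1d, rmse and train_validation_split are hypothetical, the model is a
one-feature ridge regression without intercept chosen for brevity, and the
shuffle-then-cut split and the final refit on all rows are assumptions of the
sketch.

```python
import itertools
import random


def build_param_grid(grid):
    """Expand {name: [values]} into one dict per combination of values."""
    names = sorted(grid)
    return [dict(zip(names, combo))
            for combo in itertools.product(*(grid[n] for n in names))]


def fit_ridge_1d(rows, reg_param):
    """Closed-form ridge fit for y ~ w * x (one feature, no intercept)."""
    sxy = sum(x * y for x, y in rows)
    sxx = sum(x * x for x, _ in rows)
    return sxy / (sxx + reg_param)


def rmse(rows, w):
    """Root-mean-square error of the predictions w * x against y."""
    return (sum((y - w * x) ** 2 for x, y in rows) / len(rows)) ** 0.5


def train_validation_split(rows, param_maps, train_ratio, seed):
    """Score every parameter map on one held-out split; return the best one.

    Assumption of this sketch: rows are shuffled with a seeded generator and
    cut at train_ratio, and the winning parameters are refit on all rows.
    """
    shuffled = rows[:]
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_ratio)
    train, validation = shuffled[:cut], shuffled[cut:]

    scored = []
    for params in param_maps:
        w = fit_ridge_1d(train, params["regParam"])
        scored.append((rmse(validation, w), params))

    # Lower RMSE is better, so take the minimum validation score.
    _, best_params = min(scored, key=lambda pair: pair[0])
    return best_params, fit_ridge_1d(rows, best_params["regParam"])
```

Calling build_param_grid with two values for regParam and three for
elasticNetParam, as in the example above, yields six parameter maps. The
toy model here only reads regParam, so a grid passed to
train_validation_split should vary that key.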

