Repository: spark
Updated Branches:
  refs/heads/master e00cac989 -> 034ae305c


http://git-wip-us.apache.org/repos/asf/spark/blob/034ae305/python/pyspark/ml/tests/test_persistence.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py
new file mode 100644
index 0000000..b5a2e16
--- /dev/null
+++ b/python/pyspark/ml/tests/test_persistence.py
@@ -0,0 +1,369 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+from shutil import rmtree
+import sys
+import tempfile
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+        sys.exit(1)
+else:
+    import unittest
+
+from pyspark.ml import Transformer
+from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \
+    OneVsRestModel
+from pyspark.ml.feature import Binarizer, HashingTF, PCA
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.param import Params
+from pyspark.ml.pipeline import Pipeline, PipelineModel
+from pyspark.ml.regression import DecisionTreeRegressor, LinearRegression
+from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWriter
+from pyspark.ml.wrapper import JavaParams
+from pyspark.testing.mlutils import MockUnaryTransformer, SparkSessionTestCase
+
+
+class PersistenceTest(SparkSessionTestCase):
+
+    def test_linear_regression(self):
+        lr = LinearRegression(maxIter=1)
+        path = tempfile.mkdtemp()
+        lr_path = path + "/lr"
+        lr.save(lr_path)
+        lr2 = LinearRegression.load(lr_path)
+        self.assertEqual(lr.uid, lr2.uid)
+        self.assertEqual(type(lr.uid), type(lr2.uid))
+        self.assertEqual(lr2.uid, lr2.maxIter.parent,
+                         "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
+                         % (lr2.uid, lr2.maxIter.parent))
+        self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+                         "Loaded LinearRegression instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+    def test_linear_regression_pmml_basic(self):
+        # Most of the validation is done on the Scala side; here we just check
+        # that we output text rather than parquet (e.g. that the format flag
+        # was respected).
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+                                         (0.0, 2.0, Vectors.sparse(1, [], []))],
+                                        ["label", "weight", "features"])
+        lr = LinearRegression(maxIter=1)
+        model = lr.fit(df)
+        path = tempfile.mkdtemp()
+        lr_path = path + "/lr-pmml"
+        model.write().format("pmml").save(lr_path)
+        pmml_text_list = self.sc.textFile(lr_path).collect()
+        pmml_text = "\n".join(pmml_text_list)
+        self.assertIn("Apache Spark", pmml_text)
+        self.assertIn("PMML", pmml_text)
+
+    def test_logistic_regression(self):
+        lr = LogisticRegression(maxIter=1)
+        path = tempfile.mkdtemp()
+        lr_path = path + "/logreg"
+        lr.save(lr_path)
+        lr2 = LogisticRegression.load(lr_path)
+        self.assertEqual(lr2.uid, lr2.maxIter.parent,
+                         "Loaded LogisticRegression instance uid (%s) "
+                         "did not match Param's uid (%s)"
+                         % (lr2.uid, lr2.maxIter.parent))
+        self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+                         "Loaded LogisticRegression instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+    def _compare_params(self, m1, m2, param):
+        """
+        Compare 2 ML Params instances for the given param, and assert both have the same param
+        value and parent. The param must be a parameter of m1.
+        """
+        # Prevent a KeyError when the param is in neither paramMap nor defaultParamMap.
+        if m1.isDefined(param):
+            paramValue1 = m1.getOrDefault(param)
+            paramValue2 = m2.getOrDefault(m2.getParam(param.name))
+            if isinstance(paramValue1, Params):
+                self._compare_pipelines(paramValue1, paramValue2)
+            else:
+                self.assertEqual(paramValue1, paramValue2)  # for params of general types
+            # Assert parents are equal
+            self.assertEqual(param.parent, m2.getParam(param.name).parent)
+        else:
+            # If the param is not defined on m1, it should not be defined on m2
+            # either. See SPARK-14931.
+            self.assertFalse(m2.isDefined(m2.getParam(param.name)))
+
+    def _compare_pipelines(self, m1, m2):
+        """
+        Compare 2 ML types, asserting that they are equivalent.
+        This currently supports:
+         - basic types
+         - Pipeline, PipelineModel
+         - OneVsRest, OneVsRestModel
+        This checks:
+         - uid
+         - type
+         - Param values and parents
+        """
+        self.assertEqual(m1.uid, m2.uid)
+        self.assertEqual(type(m1), type(m2))
+        if isinstance(m1, JavaParams) or isinstance(m1, Transformer):
+            self.assertEqual(len(m1.params), len(m2.params))
+            for p in m1.params:
+                self._compare_params(m1, m2, p)
+        elif isinstance(m1, Pipeline):
+            self.assertEqual(len(m1.getStages()), len(m2.getStages()))
+            for s1, s2 in zip(m1.getStages(), m2.getStages()):
+                self._compare_pipelines(s1, s2)
+        elif isinstance(m1, PipelineModel):
+            self.assertEqual(len(m1.stages), len(m2.stages))
+            for s1, s2 in zip(m1.stages, m2.stages):
+                self._compare_pipelines(s1, s2)
+        elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel):
+            for p in m1.params:
+                self._compare_params(m1, m2, p)
+            if isinstance(m1, OneVsRestModel):
+                self.assertEqual(len(m1.models), len(m2.models))
+                for x, y in zip(m1.models, m2.models):
+                    self._compare_pipelines(x, y)
+        else:
+            raise RuntimeError("_compare_pipelines does not yet support type: 
%s" % type(m1))
+
+    def test_pipeline_persistence(self):
+        """
+        Pipeline[HashingTF, PCA]
+        """
+        temp_path = tempfile.mkdtemp()
+
+        try:
+            df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", 
"e"],)], ["words"])
+            tf = HashingTF(numFeatures=10, inputCol="words", 
outputCol="features")
+            pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+            pl = Pipeline(stages=[tf, pca])
+            model = pl.fit(df)
+
+            pipeline_path = temp_path + "/pipeline"
+            pl.save(pipeline_path)
+            loaded_pipeline = Pipeline.load(pipeline_path)
+            self._compare_pipelines(pl, loaded_pipeline)
+
+            model_path = temp_path + "/pipeline-model"
+            model.save(model_path)
+            loaded_model = PipelineModel.load(model_path)
+            self._compare_pipelines(model, loaded_model)
+        finally:
+            try:
+                rmtree(temp_path)
+            except OSError:
+                pass
+
+    def test_nested_pipeline_persistence(self):
+        """
+        Pipeline[HashingTF, Pipeline[PCA]]
+        """
+        temp_path = tempfile.mkdtemp()
+
+        try:
+            df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", 
"e"],)], ["words"])
+            tf = HashingTF(numFeatures=10, inputCol="words", 
outputCol="features")
+            pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+            p0 = Pipeline(stages=[pca])
+            pl = Pipeline(stages=[tf, p0])
+            model = pl.fit(df)
+
+            pipeline_path = temp_path + "/pipeline"
+            pl.save(pipeline_path)
+            loaded_pipeline = Pipeline.load(pipeline_path)
+            self._compare_pipelines(pl, loaded_pipeline)
+
+            model_path = temp_path + "/pipeline-model"
+            model.save(model_path)
+            loaded_model = PipelineModel.load(model_path)
+            self._compare_pipelines(model, loaded_model)
+        finally:
+            try:
+                rmtree(temp_path)
+            except OSError:
+                pass
+
+    def test_python_transformer_pipeline_persistence(self):
+        """
+        Pipeline[MockUnaryTransformer, Binarizer]
+        """
+        temp_path = tempfile.mkdtemp()
+
+        try:
+            df = self.spark.range(0, 10).toDF('input')
+            tf = MockUnaryTransformer(shiftVal=2)\
+                .setInputCol("input").setOutputCol("shiftedInput")
+            tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
+            pl = Pipeline(stages=[tf, tf2])
+            model = pl.fit(df)
+
+            pipeline_path = temp_path + "/pipeline"
+            pl.save(pipeline_path)
+            loaded_pipeline = Pipeline.load(pipeline_path)
+            self._compare_pipelines(pl, loaded_pipeline)
+
+            model_path = temp_path + "/pipeline-model"
+            model.save(model_path)
+            loaded_model = PipelineModel.load(model_path)
+            self._compare_pipelines(model, loaded_model)
+        finally:
+            try:
+                rmtree(temp_path)
+            except OSError:
+                pass
+
+    def test_onevsrest(self):
+        temp_path = tempfile.mkdtemp()
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+                                         (1.0, Vectors.sparse(2, [], [])),
+                                         (2.0, Vectors.dense(0.5, 0.5))] * 10,
+                                        ["label", "features"])
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr)
+        model = ovr.fit(df)
+        ovrPath = temp_path + "/ovr"
+        ovr.save(ovrPath)
+        loadedOvr = OneVsRest.load(ovrPath)
+        self._compare_pipelines(ovr, loadedOvr)
+        modelPath = temp_path + "/ovrModel"
+        model.save(modelPath)
+        loadedModel = OneVsRestModel.load(modelPath)
+        self._compare_pipelines(model, loadedModel)
+
+    def test_decisiontree_classifier(self):
+        dt = DecisionTreeClassifier(maxDepth=1)
+        path = tempfile.mkdtemp()
+        dtc_path = path + "/dtc"
+        dt.save(dtc_path)
+        dt2 = DecisionTreeClassifier.load(dtc_path)
+        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
+                         "Loaded DecisionTreeClassifier instance uid (%s) "
+                         "did not match Param's uid (%s)"
+                         % (dt2.uid, dt2.maxDepth.parent))
+        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
+                         "Loaded DecisionTreeClassifier instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+    def test_decisiontree_regressor(self):
+        dt = DecisionTreeRegressor(maxDepth=1)
+        path = tempfile.mkdtemp()
+        dtr_path = path + "/dtr"
+        dt.save(dtr_path)
+        dt2 = DecisionTreeRegressor.load(dtr_path)
+        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
+                         "Loaded DecisionTreeRegressor instance uid (%s) "
+                         "did not match Param's uid (%s)"
+                         % (dt2.uid, dt2.maxDepth.parent))
+        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
+                         "Loaded DecisionTreeRegressor instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+    def test_default_read_write(self):
+        temp_path = tempfile.mkdtemp()
+
+        lr = LogisticRegression()
+        lr.setMaxIter(50)
+        lr.setThreshold(.75)
+        writer = DefaultParamsWriter(lr)
+
+        savePath = temp_path + "/lr"
+        writer.save(savePath)
+
+        reader = DefaultParamsReadable.read()
+        lr2 = reader.load(savePath)
+
+        self.assertEqual(lr.uid, lr2.uid)
+        self.assertEqual(lr.extractParamMap(), lr2.extractParamMap())
+
+        # test overwrite
+        lr.setThreshold(.8)
+        writer.overwrite().save(savePath)
+
+        reader = DefaultParamsReadable.read()
+        lr3 = reader.load(savePath)
+
+        self.assertEqual(lr.uid, lr3.uid)
+        self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
+
+    def test_default_read_write_default_params(self):
+        lr = LogisticRegression()
+        self.assertFalse(lr.isSet(lr.getParam("threshold")))
+
+        lr.setMaxIter(50)
+        lr.setThreshold(.75)
+
+        # `threshold` is set by the user; the default param `predictionCol` is not.
+        self.assertTrue(lr.isSet(lr.getParam("threshold")))
+        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+        writer = DefaultParamsWriter(lr)
+        metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
+        self.assertTrue("defaultParamMap" in metadata)
+
+        reader = DefaultParamsReadable.read()
+        metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr)
+        reader.getAndSetParams(lr, loadedMetadata)
+
+        self.assertTrue(lr.isSet(lr.getParam("threshold")))
+        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+        # manually create metadata without `defaultParamMap` section.
+        del metadata['defaultParamMap']
+        metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr)
+        with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
+            reader.getAndSetParams(lr, loadedMetadata)
+
+        # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
+        metadata['sparkVersion'] = '2.3.0'
+        metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr)
+        reader.getAndSetParams(lr, loadedMetadata)
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_persistence import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/034ae305/python/pyspark/ml/tests/test_pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py
new file mode 100644
index 0000000..31ef02c
--- /dev/null
+++ b/python/pyspark/ml/tests/test_pipeline.py
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import sys
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+        sys.exit(1)
+else:
+    import unittest
+
+from pyspark.ml.pipeline import Pipeline
+from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer, PySparkTestCase
+
+
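+# These tests lean on the mocks in pyspark.testing.mlutils: MockDataset
+# carries an `index` counter that each MockTransformer/MockEstimator bumps
+# when it touches the data, so the expected indices below encode how many
+# stages processed the dataset (an assumption about the mocks' contract).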
+class PipelineTests(PySparkTestCase):
+
+    def test_pipeline(self):
+        dataset = MockDataset()
+        estimator0 = MockEstimator()
+        transformer1 = MockTransformer()
+        estimator2 = MockEstimator()
+        transformer3 = MockTransformer()
+        pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3])
+        pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
+        model0, transformer1, model2, transformer3 = pipeline_model.stages
+        self.assertEqual(0, model0.dataset_index)
+        self.assertEqual(0, model0.getFake())
+        self.assertEqual(1, transformer1.dataset_index)
+        self.assertEqual(1, transformer1.getFake())
+        self.assertEqual(2, dataset.index)
+        self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.")
+        self.assertIsNone(transformer3.dataset_index,
+                          "The last transformer shouldn't be called in fit.")
+        dataset = pipeline_model.transform(dataset)
+        self.assertEqual(2, model0.dataset_index)
+        self.assertEqual(3, transformer1.dataset_index)
+        self.assertEqual(4, model2.dataset_index)
+        self.assertEqual(5, transformer3.dataset_index)
+        self.assertEqual(6, dataset.index)
+
+    def test_identity_pipeline(self):
+        dataset = MockDataset()
+
+        def doTransform(pipeline):
+            pipeline_model = pipeline.fit(dataset)
+            return pipeline_model.transform(dataset)
+        # check that an empty pipeline does not perform any transformation
+        self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index)
+        # check that failing to set the stages param raises a KeyError for the missing param
+        self.assertRaises(KeyError, lambda: doTransform(Pipeline()))
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_pipeline import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/034ae305/python/pyspark/ml/tests/test_stat.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py
new file mode 100644
index 0000000..bdc4853
--- /dev/null
+++ b/python/pyspark/ml/tests/test_stat.py
@@ -0,0 +1,58 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+        sys.exit(1)
+else:
+    import unittest
+
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.stat import ChiSquareTest
+from pyspark.sql import DataFrame
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class ChiSquareTestTests(SparkSessionTestCase):
+
+    def test_chisquaretest(self):
+        data = [[0, Vectors.dense([0, 1, 2])],
+                [1, Vectors.dense([1, 1, 1])],
+                [2, Vectors.dense([2, 1, 0])]]
+        df = self.spark.createDataFrame(data, ['label', 'feat'])
+        res = ChiSquareTest.test(df, 'feat', 'label')
+        # This line hits the collect bug described in #17218; commented out for now.
+        # pValues = res.select("degreesOfFreedom").collect()
+        self.assertIsInstance(res, DataFrame)
+        fieldNames = set(field.name for field in res.schema.fields)
+        expectedFields = ["pValues", "degreesOfFreedom", "statistics"]
+        self.assertTrue(all(field in fieldNames for field in expectedFields))
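+        # A sketch of how the result could be inspected once the collect issue
+        # referenced above is resolved (assuming the standard output schema):
+        #   row = res.first()
+        #   row.pValues, row.degreesOfFreedom, row.statistics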
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_stat import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/034ae305/python/pyspark/ml/tests/test_training_summary.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py
new file mode 100644
index 0000000..d5464f7
--- /dev/null
+++ b/python/pyspark/ml/tests/test_training_summary.py
@@ -0,0 +1,258 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+        sys.exit(1)
+else:
+    import unittest
+
+if sys.version > '3':
+    basestring = str
+
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
+from pyspark.sql import DataFrame
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class TrainingSummaryTest(SparkSessionTestCase):
+
+    def test_linear_regression_summary(self):
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+                                         (0.0, 2.0, Vectors.sparse(1, [], []))],
+                                        ["label", "weight", "features"])
+        lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight",
+                              fitIntercept=False)
+        model = lr.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
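+        # (Guarding on hasSummary matters: accessing model.summary on a model
+        # without a summary would raise an error on the JVM side.)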
+        # test that api is callable and returns expected types
+        self.assertGreater(s.totalIterations, 0)
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.predictionCol, "prediction")
+        self.assertEqual(s.labelCol, "label")
+        self.assertEqual(s.featuresCol, "features")
+        objHist = s.objectiveHistory
+        self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
+        self.assertAlmostEqual(s.explainedVariance, 0.25, 2)
+        self.assertAlmostEqual(s.meanAbsoluteError, 0.0)
+        self.assertAlmostEqual(s.meanSquaredError, 0.0)
+        self.assertAlmostEqual(s.rootMeanSquaredError, 0.0)
+        self.assertAlmostEqual(s.r2, 1.0, 2)
+        self.assertAlmostEqual(s.r2adj, 1.0, 2)
+        self.assertTrue(isinstance(s.residuals, DataFrame))
+        self.assertEqual(s.numInstances, 2)
+        self.assertEqual(s.degreesOfFreedom, 1)
+        devResiduals = s.devianceResiduals
+        self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float))
+        coefStdErr = s.coefficientStandardErrors
+        self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float))
+        tValues = s.tValues
+        self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float))
+        pValues = s.pValues
+        self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float))
+        # test evaluation (with training dataset) produces a summary with same values
+        # one check is enough to verify a summary is returned
+        # The child class LinearRegressionTrainingSummary runs the full test
+        sameSummary = model.evaluate(df)
+        self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance)
+
+    def test_glr_summary(self):
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+                                         (0.0, 2.0, Vectors.sparse(1, [], []))],
+                                        ["label", "weight", "features"])
+        glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight",
+                                          fitIntercept=False)
+        model = glr.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
+        # test that api is callable and returns expected types
+        self.assertEqual(s.numIterations, 1)  # this should default to a single iteration of WLS
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.predictionCol, "prediction")
+        self.assertEqual(s.numInstances, 2)
+        self.assertTrue(isinstance(s.residuals(), DataFrame))
+        self.assertTrue(isinstance(s.residuals("pearson"), DataFrame))
+        coefStdErr = s.coefficientStandardErrors
+        self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float))
+        tValues = s.tValues
+        self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float))
+        pValues = s.pValues
+        self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float))
+        self.assertEqual(s.degreesOfFreedom, 1)
+        self.assertEqual(s.residualDegreeOfFreedom, 1)
+        self.assertEqual(s.residualDegreeOfFreedomNull, 2)
+        self.assertEqual(s.rank, 1)
+        self.assertTrue(isinstance(s.solver, basestring))
+        self.assertTrue(isinstance(s.aic, float))
+        self.assertTrue(isinstance(s.deviance, float))
+        self.assertTrue(isinstance(s.nullDeviance, float))
+        self.assertTrue(isinstance(s.dispersion, float))
+        # test evaluation (with training dataset) produces a summary with same values
+        # one check is enough to verify a summary is returned
+        # The child class GeneralizedLinearRegressionTrainingSummary runs the full test
+        sameSummary = model.evaluate(df)
+        self.assertAlmostEqual(sameSummary.deviance, s.deviance)
+
+    def test_binary_logistic_regression_summary(self):
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+                                         (0.0, 2.0, Vectors.sparse(1, [], []))],
+                                        ["label", "weight", "features"])
+        lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
+        model = lr.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
+        # test that api is callable and returns expected types
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.probabilityCol, "probability")
+        self.assertEqual(s.labelCol, "label")
+        self.assertEqual(s.featuresCol, "features")
+        self.assertEqual(s.predictionCol, "prediction")
+        objHist = s.objectiveHistory
+        self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
+        self.assertGreater(s.totalIterations, 0)
+        self.assertTrue(isinstance(s.labels, list))
+        self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
+        self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
+        self.assertTrue(isinstance(s.precisionByLabel, list))
+        self.assertTrue(isinstance(s.recallByLabel, list))
+        self.assertTrue(isinstance(s.fMeasureByLabel(), list))
+        self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
+        self.assertTrue(isinstance(s.roc, DataFrame))
+        self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
+        self.assertTrue(isinstance(s.pr, DataFrame))
+        self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
+        self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
+        self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
+        self.assertAlmostEqual(s.accuracy, 1.0, 2)
+        self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
+        self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
+        self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
+        self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
+        self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
+        self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
+        # test evaluation (with training dataset) produces a summary with same values
+        # one check is enough to verify a summary is returned; the Scala version runs the full test
+        sameSummary = model.evaluate(df)
+        self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
+
+    def test_multiclass_logistic_regression_summary(self):
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+                                         (0.0, 2.0, Vectors.sparse(1, [], [])),
+                                         (2.0, 2.0, Vectors.dense(2.0)),
+                                         (2.0, 2.0, Vectors.dense(1.9))],
+                                        ["label", "weight", "features"])
+        lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
+        model = lr.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
+        # test that api is callable and returns expected types
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.probabilityCol, "probability")
+        self.assertEqual(s.labelCol, "label")
+        self.assertEqual(s.featuresCol, "features")
+        self.assertEqual(s.predictionCol, "prediction")
+        objHist = s.objectiveHistory
+        self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
+        self.assertGreater(s.totalIterations, 0)
+        self.assertTrue(isinstance(s.labels, list))
+        self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
+        self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
+        self.assertTrue(isinstance(s.precisionByLabel, list))
+        self.assertTrue(isinstance(s.recallByLabel, list))
+        self.assertTrue(isinstance(s.fMeasureByLabel(), list))
+        self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
+        self.assertAlmostEqual(s.accuracy, 0.75, 2)
+        self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
+        self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
+        self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
+        self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
+        self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2)
+        self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2)
+        # test evaluation (with training dataset) produces a summary with same values
+        # one check is enough to verify a summary is returned; the Scala version runs the full test
+        sameSummary = model.evaluate(df)
+        self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
+
+    def test_gaussian_mixture_summary(self):
+        data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
+                (Vectors.sparse(1, [], []),)]
+        df = self.spark.createDataFrame(data, ["features"])
+        gmm = GaussianMixture(k=2)
+        model = gmm.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.probabilityCol, "probability")
+        self.assertTrue(isinstance(s.probability, DataFrame))
+        self.assertEqual(s.featuresCol, "features")
+        self.assertEqual(s.predictionCol, "prediction")
+        self.assertTrue(isinstance(s.cluster, DataFrame))
+        self.assertEqual(len(s.clusterSizes), 2)
+        self.assertEqual(s.k, 2)
+        self.assertEqual(s.numIter, 3)
+
+    def test_bisecting_kmeans_summary(self):
+        data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
+                (Vectors.sparse(1, [], []),)]
+        df = self.spark.createDataFrame(data, ["features"])
+        bkm = BisectingKMeans(k=2)
+        model = bkm.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.featuresCol, "features")
+        self.assertEqual(s.predictionCol, "prediction")
+        self.assertTrue(isinstance(s.cluster, DataFrame))
+        self.assertEqual(len(s.clusterSizes), 2)
+        self.assertEqual(s.k, 2)
+        self.assertEqual(s.numIter, 20)
+
+    def test_kmeans_summary(self):
+        data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
+                (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
+        df = self.spark.createDataFrame(data, ["features"])
+        kmeans = KMeans(k=2, seed=1)
+        model = kmeans.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.featuresCol, "features")
+        self.assertEqual(s.predictionCol, "prediction")
+        self.assertTrue(isinstance(s.cluster, DataFrame))
+        self.assertEqual(len(s.clusterSizes), 2)
+        self.assertEqual(s.k, 2)
+        self.assertEqual(s.numIter, 1)
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_training_summary import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/034ae305/python/pyspark/ml/tests/test_tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
new file mode 100644
index 0000000..af00d1d
--- /dev/null
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -0,0 +1,552 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import tempfile
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+        sys.exit(1)
+else:
+    import unittest
+
+from pyspark.ml import Estimator, Model
+from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel, OneVsRest
+from pyspark.ml.evaluation import BinaryClassificationEvaluator, \
+    MulticlassClassificationEvaluator, RegressionEvaluator
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.param import Param, Params
+from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder, \
+    TrainValidationSplit, TrainValidationSplitModel
+from pyspark.sql.functions import rand
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class HasInducedError(Params):
+
+    def __init__(self):
+        super(HasInducedError, self).__init__()
+        self.inducedError = Param(self, "inducedError",
+                                  "Uniformly-distributed error added to 
feature")
+
+    def getInducedError(self):
+        return self.getOrDefault(self.inducedError)
+
+
+class InducedErrorModel(Model, HasInducedError):
+
+    def __init__(self):
+        super(InducedErrorModel, self).__init__()
+
+    def _transform(self, dataset):
+        return dataset.withColumn("prediction",
+                                  dataset.feature + (rand(0) * self.getInducedError()))
+
+
+class InducedErrorEstimator(Estimator, HasInducedError):
+
+    def __init__(self, inducedError=1.0):
+        super(InducedErrorEstimator, self).__init__()
+        self._set(inducedError=inducedError)
+
+    def _fit(self, dataset):
+        model = InducedErrorModel()
+        self._copyValues(model)
+        return model
+
+
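+# How the mocks above drive the tuning tests, in brief: fitting
+# InducedErrorEstimator(inducedError=e) yields a model whose predictions
+# deviate from the label by up to e, so RMSE grows with e and the grid
+# point with inducedError=0.0 should always win model selection.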
+class CrossValidatorTests(SparkSessionTestCase):
+
+    def test_copy(self):
+        dataset = self.spark.createDataFrame([
+            (10, 10.0),
+            (50, 50.0),
+            (100, 100.0),
+            (500, 500.0)] * 10,
+            ["feature", "label"])
+
+        iee = InducedErrorEstimator()
+        evaluator = RegressionEvaluator(metricName="rmse")
+
+        grid = (ParamGridBuilder()
+                .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
+                .build())
+        cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+        cvCopied = cv.copy()
+        self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid)
+
+        cvModel = cv.fit(dataset)
+        cvModelCopied = cvModel.copy()
+        for index in range(len(cvModel.avgMetrics)):
+            self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index])
+                            < 0.0001)
+
+    def test_fit_minimize_metric(self):
+        dataset = self.spark.createDataFrame([
+            (10, 10.0),
+            (50, 50.0),
+            (100, 100.0),
+            (500, 500.0)] * 10,
+            ["feature", "label"])
+
+        iee = InducedErrorEstimator()
+        evaluator = RegressionEvaluator(metricName="rmse")
+
+        grid = (ParamGridBuilder()
+                .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
+                .build())
+        cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+        cvModel = cv.fit(dataset)
+        bestModel = cvModel.bestModel
+        bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+
+        self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
+                         "Best model should have zero induced error")
+        self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
+
+    def test_fit_maximize_metric(self):
+        dataset = self.spark.createDataFrame([
+            (10, 10.0),
+            (50, 50.0),
+            (100, 100.0),
+            (500, 500.0)] * 10,
+            ["feature", "label"])
+
+        iee = InducedErrorEstimator()
+        evaluator = RegressionEvaluator(metricName="r2")
+
+        grid = (ParamGridBuilder()
+                .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
+                .build())
+        cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+        cvModel = cv.fit(dataset)
+        bestModel = cvModel.bestModel
+        bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+
+        self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
+                         "Best model should have zero induced error")
+        self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+
+    def test_param_grid_type_coercion(self):
+        lr = LogisticRegression(maxIter=10)
+        paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build()
+        for param in paramGrid:
+            for v in param.values():
+                self.assertEqual(type(v), float)
+
+    def test_save_load_trained_model(self):
+        # This tests saving and loading the trained model only.
+        # Save/load for CrossValidator will be added later: SPARK-13786
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+        cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+        cvModel = cv.fit(dataset)
+        lrModel = cvModel.bestModel
+
+        cvModelPath = temp_path + "/cvModel"
+        lrModel.save(cvModelPath)
+        loadedLrModel = LogisticRegressionModel.load(cvModelPath)
+        self.assertEqual(loadedLrModel.uid, lrModel.uid)
+        self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
+
+    def test_save_load_simple_estimator(self):
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+
+        # test save/load of CrossValidator
+        cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+        cvModel = cv.fit(dataset)
+        cvPath = temp_path + "/cv"
+        cv.save(cvPath)
+        loadedCV = CrossValidator.load(cvPath)
+        self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+        self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+        self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+
+        # test save/load of CrossValidatorModel
+        cvModelPath = temp_path + "/cvModel"
+        cvModel.save(cvModelPath)
+        loadedModel = CrossValidatorModel.load(cvModelPath)
+        self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
+    def test_parallel_evaluation(self):
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build()
+        evaluator = BinaryClassificationEvaluator()
+
+        cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+        cv.setParallelism(1)
+        cvSerialModel = cv.fit(dataset)
+        cv.setParallelism(2)
+        cvParallelModel = cv.fit(dataset)
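+        # Parallelism only changes how many models are fit concurrently; the
+        # folds and the param grid are identical, so serial and parallel
+        # metrics should match exactly.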
+        self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics)
+
+    def test_expose_sub_models(self):
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+
+        numFolds = 3
+        cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
+                            numFolds=numFolds, collectSubModels=True)
+
+        def checkSubModels(subModels):
+            self.assertEqual(len(subModels), numFolds)
+            for i in range(numFolds):
+                self.assertEqual(len(subModels[i]), len(grid))
+
+        cvModel = cv.fit(dataset)
+        checkSubModels(cvModel.subModels)
+
+        # Test that the default value for the "persistSubModels" option is "true"
+        testSubPath = temp_path + "/testCrossValidatorSubModels"
+        savingPathWithSubModels = testSubPath + "cvModel3"
+        cvModel.save(savingPathWithSubModels)
+        cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
+        checkSubModels(cvModel3.subModels)
+        cvModel4 = cvModel3.copy()
+        checkSubModels(cvModel4.subModels)
+
+        savingPathWithoutSubModels = testSubPath + "cvModel2"
+        cvModel.write().option("persistSubModels", 
"false").save(savingPathWithoutSubModels)
+        cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
+        self.assertEqual(cvModel2.subModels, None)
+
+        for i in range(numFolds):
+            for j in range(len(grid)):
+                self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid)
+
+    def test_save_load_nested_estimator(self):
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+
+        ova = OneVsRest(classifier=LogisticRegression())
+        lr1 = LogisticRegression().setMaxIter(100)
+        lr2 = LogisticRegression().setMaxIter(150)
+        grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
+        evaluator = MulticlassClassificationEvaluator()
+
+        # test save/load of CrossValidator
+        cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
+        cvModel = cv.fit(dataset)
+        cvPath = temp_path + "/cv"
+        cv.save(cvPath)
+        loadedCV = CrossValidator.load(cvPath)
+        self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+        self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+
+        originalParamMap = cv.getEstimatorParamMaps()
+        loadedParamMap = loadedCV.getEstimatorParamMaps()
+        for i, param in enumerate(loadedParamMap):
+            for p in param:
+                if p.name == "classifier":
+                    self.assertEqual(param[p].uid, originalParamMap[i][p].uid)
+                else:
+                    self.assertEqual(param[p], originalParamMap[i][p])
+
+        # test save/load of CrossValidatorModel
+        cvModelPath = temp_path + "/cvModel"
+        cvModel.save(cvModelPath)
+        loadedModel = CrossValidatorModel.load(cvModelPath)
+        self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
+
+class TrainValidationSplitTests(SparkSessionTestCase):
+
+    def test_fit_minimize_metric(self):
+        dataset = self.spark.createDataFrame([
+            (10, 10.0),
+            (50, 50.0),
+            (100, 100.0),
+            (500, 500.0)] * 10,
+            ["feature", "label"])
+
+        iee = InducedErrorEstimator()
+        evaluator = RegressionEvaluator(metricName="rmse")
+
+        grid = ParamGridBuilder() \
+            .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \
+            .build()
+        tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+        tvsModel = tvs.fit(dataset)
+        bestModel = tvsModel.bestModel
+        bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+        validationMetrics = tvsModel.validationMetrics
+
+        self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
+                         "Best model should have zero induced error")
+        self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
+        self.assertEqual(len(grid), len(validationMetrics),
+                         "validationMetrics should have the same size as the grid")
+        self.assertEqual(0.0, min(validationMetrics))
+
+    def test_fit_maximize_metric(self):
+        dataset = self.spark.createDataFrame([
+            (10, 10.0),
+            (50, 50.0),
+            (100, 100.0),
+            (500, 500.0)] * 10,
+            ["feature", "label"])
+
+        iee = InducedErrorEstimator()
+        evaluator = RegressionEvaluator(metricName="r2")
+
+        grid = ParamGridBuilder() \
+            .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \
+            .build()
+        tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+        tvsModel = tvs.fit(dataset)
+        bestModel = tvsModel.bestModel
+        bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+        validationMetrics = tvsModel.validationMetrics
+
+        self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
+                         "Best model should have zero induced error")
+        self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+        self.assertEqual(len(grid), len(validationMetrics),
+                         "validationMetrics should have the same size as the grid")
+        self.assertEqual(1.0, max(validationMetrics))
+
+    def test_save_load_trained_model(self):
+        # This tests saving and loading the trained model only.
+        # Save/load for TrainValidationSplit will be added later: SPARK-13786
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+        tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+        tvsModel = tvs.fit(dataset)
+        lrModel = tvsModel.bestModel
+
+        tvsModelPath = temp_path + "/tvsModel"
+        lrModel.save(tvsModelPath)
+        loadedLrModel = LogisticRegressionModel.load(tvsModelPath)
+        self.assertEqual(loadedLrModel.uid, lrModel.uid)
+        self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
+
+    def test_save_load_simple_estimator(self):
+        # This tests saving and loading both TrainValidationSplit and
+        # TrainValidationSplitModel (SPARK-13786).
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+        tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+        tvsModel = tvs.fit(dataset)
+
+        tvsPath = temp_path + "/tvs"
+        tvs.save(tvsPath)
+        loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+        self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+        self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+
+        tvsModelPath = temp_path + "/tvsModel"
+        tvsModel.save(tvsModelPath)
+        loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+        self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
+    def test_parallel_evaluation(self):
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build()
+        evaluator = BinaryClassificationEvaluator()
+        tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+        tvs.setParallelism(1)
+        tvsSerialModel = tvs.fit(dataset)
+        tvs.setParallelism(2)
+        tvsParallelModel = tvs.fit(dataset)
+        self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics)
+
+    def test_expose_sub_models(self):
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+        tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
+                                   collectSubModels=True)
+        tvsModel = tvs.fit(dataset)
+        self.assertEqual(len(tvsModel.subModels), len(grid))
+
+        # Test that the default value for the "persistSubModels" option is "true"
+        testSubPath = temp_path + "/testTrainValidationSplitSubModels"
+        savingPathWithSubModels = testSubPath + "cvModel3"
+        tvsModel.save(savingPathWithSubModels)
+        tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels)
+        self.assertEqual(len(tvsModel3.subModels), len(grid))
+        tvsModel4 = tvsModel3.copy()
+        self.assertEqual(len(tvsModel4.subModels), len(grid))
+
+        savingPathWithoutSubModels = testSubPath + "cvModel2"
+        tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
+        tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
+        self.assertIsNone(tvsModel2.subModels)
+
+        for i in range(len(grid)):
+            self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid)
+
+    def test_save_load_nested_estimator(self):
+        # This tests saving and loading both the TrainValidationSplit (with a
+        # nested OneVsRest estimator) and the trained model (SPARK-13786).
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        ova = OneVsRest(classifier=LogisticRegression())
+        lr1 = LogisticRegression().setMaxIter(100)
+        lr2 = LogisticRegression().setMaxIter(150)
+        grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
+        evaluator = MulticlassClassificationEvaluator()
+
+        tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
+        tvsModel = tvs.fit(dataset)
+        tvsPath = temp_path + "/tvs"
+        tvs.save(tvsPath)
+        loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+        self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+
+        originalParamMap = tvs.getEstimatorParamMaps()
+        loadedParamMap = loadedTvs.getEstimatorParamMaps()
+        for i, param in enumerate(loadedParamMap):
+            for p in param:
+                if p.name == "classifier":
+                    self.assertEqual(param[p].uid, originalParamMap[i][p].uid)
+                else:
+                    self.assertEqual(param[p], originalParamMap[i][p])
+
+        tvsModelPath = temp_path + "/tvsModel"
+        tvsModel.save(tvsModelPath)
+        loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+        self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
+    def test_copy(self):
+        dataset = self.spark.createDataFrame([
+            (10, 10.0),
+            (50, 50.0),
+            (100, 100.0),
+            (500, 500.0)] * 10,
+            ["feature", "label"])
+
+        iee = InducedErrorEstimator()
+        evaluator = RegressionEvaluator(metricName="r2")
+
+        grid = ParamGridBuilder() \
+            .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \
+            .build()
+        tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+        tvsModel = tvs.fit(dataset)
+        tvsCopied = tvs.copy()
+        tvsModelCopied = tvsModel.copy()
+
+        self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid,
+                         "Copied TrainValidationSplit has the same uid of 
Estimator")
+
+        self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid)
+        self.assertEqual(len(tvsModel.validationMetrics),
+                         len(tvsModelCopied.validationMetrics),
+                         "Copied validationMetrics has the same size of the 
original")
+        for index in range(len(tvsModel.validationMetrics)):
+            self.assertEqual(tvsModel.validationMetrics[index],
+                             tvsModelCopied.validationMetrics[index])
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_tuning import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
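
For readers following the tuning tests above, the persistence behaviour they
exercise reduces to the round trip below -- a minimal sketch, not part of the
patch; the session, data, and paths are illustrative only:

    import tempfile

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit, \
        TrainValidationSplitModel
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    temp_path = tempfile.mkdtemp()

    df = spark.createDataFrame(
        [(Vectors.dense([0.0]), 0.0),
         (Vectors.dense([0.4]), 1.0),
         (Vectors.dense([0.5]), 0.0),
         (Vectors.dense([0.6]), 1.0),
         (Vectors.dense([1.0]), 1.0)] * 10,
        ["features", "label"])
    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid,
                               evaluator=BinaryClassificationEvaluator(),
                               collectSubModels=True)
    model = tvs.fit(df)

    # Sub-models are persisted by default; the writer option below opts out.
    model.write().option("persistSubModels", "false").save(temp_path + "/tvsModel")
    loaded = TrainValidationSplitModel.load(temp_path + "/tvsModel")
    assert loaded.subModels is None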

http://git-wip-us.apache.org/repos/asf/spark/blob/034ae305/python/pyspark/ml/tests/test_wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests/test_wrapper.py 
b/python/pyspark/ml/tests/test_wrapper.py
new file mode 100644
index 0000000..4326d8e
--- /dev/null
+++ b/python/pyspark/ml/tests/test_wrapper.py
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+        sys.exit(1)
+else:
+    import unittest
+
+import py4j
+
+from pyspark.ml.linalg import DenseVector, Vectors
+from pyspark.ml.regression import LinearRegression
+from pyspark.ml.wrapper import _java2py, _py2java, JavaParams, JavaWrapper
+from pyspark.testing.mllibutils import MLlibTestCase
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class JavaWrapperMemoryTests(SparkSessionTestCase):
+
+    def test_java_object_gets_detached(self):
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+                                         (0.0, 2.0, Vectors.sparse(1, [], []))],
+                                        ["label", "weight", "features"])
+        lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight",
+                              fitIntercept=False)
+
+        model = lr.fit(df)
+        summary = model.summary
+
+        self.assertIsInstance(model, JavaWrapper)
+        self.assertIsInstance(summary, JavaWrapper)
+        self.assertIsInstance(model, JavaParams)
+        self.assertNotIsInstance(summary, JavaParams)
+
+        error_no_object = 'Target Object ID does not exist for this gateway'
+
+        self.assertIn("LinearRegression_", model._java_obj.toString())
+        self.assertIn("LinearRegressionTrainingSummary", 
summary._java_obj.toString())
+
+        model.__del__()
+
+        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            model._java_obj.toString()
+        self.assertIn("LinearRegressionTrainingSummary", 
summary._java_obj.toString())
+
+        try:
+            summary.__del__()
+        except Exception:
+            # __del__ may raise if the JVM-side object was already released.
+            pass
+
+        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            model._java_obj.toString()
+        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            summary._java_obj.toString()
+
+
+class WrapperTests(MLlibTestCase):
+
+    def test_new_java_array(self):
+        # test array of strings
+        str_list = ["a", "b", "c"]
+        java_class = self.sc._gateway.jvm.java.lang.String
+        java_array = JavaWrapper._new_java_array(str_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), str_list)
+        # test array of integers
+        int_list = [1, 2, 3]
+        java_class = self.sc._gateway.jvm.java.lang.Integer
+        java_array = JavaWrapper._new_java_array(int_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), int_list)
+        # test array of floats
+        float_list = [0.1, 0.2, 0.3]
+        java_class = self.sc._gateway.jvm.java.lang.Double
+        java_array = JavaWrapper._new_java_array(float_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), float_list)
+        # test array of bools
+        bool_list = [False, True, True]
+        java_class = self.sc._gateway.jvm.java.lang.Boolean
+        java_array = JavaWrapper._new_java_array(bool_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), bool_list)
+        # test array of Java DenseVectors
+        v1 = DenseVector([0.0, 1.0])
+        v2 = DenseVector([1.0, 0.0])
+        vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)]
+        java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector
+        java_array = JavaWrapper._new_java_array(vec_java_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), [v1, v2])
+        # test empty array
+        java_class = self.sc._gateway.jvm.java.lang.Integer
+        java_array = JavaWrapper._new_java_array([], java_class)
+        self.assertEqual(_java2py(self.sc, java_array), [])
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_wrapper import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
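
The array-construction helper exercised by WrapperTests can also be driven
standalone -- a minimal sketch; note that _new_java_array and _java2py are
private helpers, so this is illustrative rather than public API:

    from pyspark import SparkContext
    from pyspark.ml.wrapper import JavaWrapper, _java2py

    sc = SparkContext.getOrCreate()

    # Build a java.lang.Double[] from a Python list, then convert it back.
    double_cls = sc._gateway.jvm.java.lang.Double
    java_array = JavaWrapper._new_java_array([0.1, 0.2, 0.3], double_cls)
    assert _java2py(sc, java_array) == [0.1, 0.2, 0.3]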

http://git-wip-us.apache.org/repos/asf/spark/blob/034ae305/python/pyspark/testing/mlutils.py
----------------------------------------------------------------------
diff --git a/python/pyspark/testing/mlutils.py 
b/python/pyspark/testing/mlutils.py
new file mode 100644
index 0000000..12bf650
--- /dev/null
+++ b/python/pyspark/testing/mlutils.py
@@ -0,0 +1,161 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+
+from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer
+from pyspark.ml.param import Param, Params, TypeConverters
+from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
+from pyspark.ml.wrapper import _java2py
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.types import DoubleType
+from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase
+
+
+def check_params(test_self, py_stage, check_params_exist=True):
+    """
+    Checks common requirements for Params.params:
+      - set of params exist in Java and Python and are ordered by names
+      - param parent has the same UID as the object's UID
+      - default param value from Java matches value in Python
+      - optionally check if all params from Java also exist in Python
+    """
+    py_stage_str = "%s %s" % (type(py_stage), py_stage)
+    if not hasattr(py_stage, "_to_java"):
+        return
+    java_stage = py_stage._to_java()
+    if java_stage is None:
+        return
+    test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str)
+    if check_params_exist:
+        param_names = [p.name for p in py_stage.params]
+        java_params = list(java_stage.params())
+        java_param_names = [jp.name() for jp in java_params]
+        test_self.assertEqual(
+            param_names, sorted(java_param_names),
+            "Param list in Python does not match Java for %s:\nJava = 
%s\nPython = %s"
+            % (py_stage_str, java_param_names, param_names))
+    for p in py_stage.params:
+        test_self.assertEqual(p.parent, py_stage.uid)
+        java_param = java_stage.getParam(p.name)
+        py_has_default = py_stage.hasDefault(p)
+        java_has_default = java_stage.hasDefault(java_param)
+        test_self.assertEqual(py_has_default, java_has_default,
+                              "Default value mismatch of param %s for Params 
%s"
+                              % (p.name, str(py_stage)))
+        if py_has_default:
+            if p.name == "seed":
+                continue  # Random seeds between Spark and PySpark are different
+            java_default = _java2py(test_self.sc,
+                                    java_stage.clear(java_param).getOrDefault(java_param))
+            py_stage._clear(p)
+            py_default = py_stage.getOrDefault(p)
+            # equality test for NaN is always False
+            if isinstance(java_default, float) and np.isnan(java_default):
+                java_default = "NaN"
+                py_default = "NaN" if np.isnan(py_default) else "not NaN"
+            test_self.assertEqual(
+                java_default, py_default,
+                "Java default %s != python default %s of param %s for Params 
%s"
+                % (str(java_default), str(py_default), p.name, str(py_stage)))
+
+
+class SparkSessionTestCase(PySparkTestCase):
+    @classmethod
+    def setUpClass(cls):
+        PySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        PySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
+
+class MockDataset(DataFrame):
+
+    def __init__(self):
+        self.index = 0
+
+
+class HasFake(Params):
+
+    def __init__(self):
+        super(HasFake, self).__init__()
+        self.fake = Param(self, "fake", "fake param")
+
+    def getFake(self):
+        return self.getOrDefault(self.fake)
+
+
+class MockTransformer(Transformer, HasFake):
+
+    def __init__(self):
+        super(MockTransformer, self).__init__()
+        self.dataset_index = None
+
+    def _transform(self, dataset):
+        self.dataset_index = dataset.index
+        dataset.index += 1
+        return dataset
+
+
+class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
+
+    shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
+                  "data in a DataFrame",
+                  typeConverter=TypeConverters.toFloat)
+
+    def __init__(self, shiftVal=1):
+        super(MockUnaryTransformer, self).__init__()
+        self._setDefault(shift=1)
+        self._set(shift=shiftVal)
+
+    def getShift(self):
+        return self.getOrDefault(self.shift)
+
+    def setShift(self, shift):
+        self._set(shift=shift)
+
+    def createTransformFunc(self):
+        shiftVal = self.getShift()
+        return lambda x: x + shiftVal
+
+    def outputDataType(self):
+        return DoubleType()
+
+    def validateInputType(self, inputType):
+        if inputType != DoubleType():
+            raise TypeError("Bad input type: {}. ".format(inputType) +
+                            "Requires Double.")
+
+
+class MockEstimator(Estimator, HasFake):
+
+    def __init__(self):
+        super(MockEstimator, self).__init__()
+        self.dataset_index = None
+
+    def _fit(self, dataset):
+        self.dataset_index = dataset.index
+        model = MockModel()
+        self._copyValues(model)
+        return model
+
+
+class MockModel(MockTransformer, Model, HasFake):
+    pass
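
As a usage sketch of the new helpers (the session and column names below are
assumptions for illustration, not part of the patch), MockUnaryTransformer
plugs into a pipeline like any real stage:

    from pyspark.ml.pipeline import Pipeline
    from pyspark.sql import SparkSession
    from pyspark.testing.mlutils import MockUnaryTransformer

    spark = SparkSession.builder.getOrCreate()

    # Shift a Double input column by 2.0 via a one-stage pipeline.
    df = spark.createDataFrame([(0.0,), (1.0,)], ["input"])
    shifter = MockUnaryTransformer(shiftVal=2).setInputCol("input").setOutputCol("output")
    result = Pipeline(stages=[shifter]).fit(df).transform(df)
    assert [row.output for row in result.collect()] == [2.0, 3.0]

Similarly, inside a SparkSessionTestCase, Python/Java param parity for any
Java-backed stage can be verified with, e.g.,
check_params(self, Binarizer(), check_params_exist=True).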


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
