Repository: spark
Updated Branches:
  refs/heads/master a3550e374 -> 130b8d07b


[SPARK-15008][ML][PYSPARK] Add integration test for OneVsRest

## What changes were proposed in this pull request?

1. Add `_transfer_param_map_to/from_java` for OneVsRest;

2. Add `_compare_params` in ml/tests.py to help compare params.

3. Add `test_onevsrest` as the integration test for OneVsRest.

## How was this patch tested?

Python unit test.

Author: yinxusen <yinxu...@gmail.com>

Closes #12875 from yinxusen/SPARK-15008.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/130b8d07
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/130b8d07
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/130b8d07

Branch: refs/heads/master
Commit: 130b8d07b8eb08f2ad522081a95032b90247094d
Parents: a3550e3
Author: yinxusen <yinxu...@gmail.com>
Authored: Fri May 27 13:18:29 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Fri May 27 13:18:29 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/tests.py | 69 +++++++++++++++++++++++++++--------------
 1 file changed, 46 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/130b8d07/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index a7c93ac..4358175 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -747,12 +747,32 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass
 
+    def _compare_params(self, m1, m2, param):
+        """
+        Compare 2 ML Params instances for the given param, and assert both have the same param value
+        and parent. The param must be a parameter of m1.
+        """
+        # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap.
+        if m1.isDefined(param):
+            paramValue1 = m1.getOrDefault(param)
+            paramValue2 = m2.getOrDefault(m2.getParam(param.name))
+            if isinstance(paramValue1, Params):
+                self._compare_pipelines(paramValue1, paramValue2)
+            else:
+                self.assertEqual(paramValue1, paramValue2)  # for general types param
+            # Assert parents are equal
+            self.assertEqual(param.parent, m2.getParam(param.name).parent)
+        else:
+            # If m1 is not defined param, then m2 should not, too. See SPARK-14931.
+            self.assertFalse(m2.isDefined(m2.getParam(param.name)))
+
     def _compare_pipelines(self, m1, m2):
         """
         Compare 2 ML types, asserting that they are equivalent.
         This currently supports:
          - basic types
          - Pipeline, PipelineModel
+         - OneVsRest, OneVsRestModel
         This checks:
          - uid
          - type
@@ -763,8 +783,7 @@ class PersistenceTest(SparkSessionTestCase):
         if isinstance(m1, JavaParams):
             self.assertEqual(len(m1.params), len(m2.params))
             for p in m1.params:
-                self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
-                self.assertEqual(p.parent, m2.getParam(p.name).parent)
+                self._compare_params(m1, m2, p)
         elif isinstance(m1, Pipeline):
             self.assertEqual(len(m1.getStages()), len(m2.getStages()))
             for s1, s2 in zip(m1.getStages(), m2.getStages()):
@@ -773,6 +792,13 @@ class PersistenceTest(SparkSessionTestCase):
             self.assertEqual(len(m1.stages), len(m2.stages))
             for s1, s2 in zip(m1.stages, m2.stages):
                 self._compare_pipelines(s1, s2)
+        elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel):
+            for p in m1.params:
+                self._compare_params(m1, m2, p)
+            if isinstance(m1, OneVsRestModel):
+                self.assertEqual(len(m1.models), len(m2.models))
+                for x, y in zip(m1.models, m2.models):
+                    self._compare_pipelines(x, y)
         else:
             raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
 
@@ -833,6 +859,24 @@ class PersistenceTest(SparkSessionTestCase):
             except OSError:
                 pass
 
+    def test_onevsrest(self):
+        temp_path = tempfile.mkdtemp()
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+                                         (1.0, Vectors.sparse(2, [], [])),
+                                         (2.0, Vectors.dense(0.5, 0.5))] * 10,
+                                        ["label", "features"])
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr)
+        model = ovr.fit(df)
+        ovrPath = temp_path + "/ovr"
+        ovr.save(ovrPath)
+        loadedOvr = OneVsRest.load(ovrPath)
+        self._compare_pipelines(ovr, loadedOvr)
+        modelPath = temp_path + "/ovrModel"
+        model.save(modelPath)
+        loadedModel = OneVsRestModel.load(modelPath)
+        self._compare_pipelines(model, loadedModel)
+
     def test_decisiontree_classifier(self):
         dt = DecisionTreeClassifier(maxDepth=1)
         path = tempfile.mkdtemp()
@@ -1054,27 +1098,6 @@ class OneVsRestTests(SparkSessionTestCase):
         output = model.transform(df)
         self.assertEqual(output.columns, ["label", "features", "prediction"])
 
-    def test_save_load(self):
-        temp_path = tempfile.mkdtemp()
-        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
-                                         (1.0, Vectors.sparse(2, [], [])),
-                                         (2.0, Vectors.dense(0.5, 0.5))],
-                                        ["label", "features"])
-        lr = LogisticRegression(maxIter=5, regParam=0.01)
-        ovr = OneVsRest(classifier=lr)
-        model = ovr.fit(df)
-        ovrPath = temp_path + "/ovr"
-        ovr.save(ovrPath)
-        loadedOvr = OneVsRest.load(ovrPath)
-        self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol())
-        self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol())
-        self.assertEqual(loadedOvr.getClassifier().uid, ovr.getClassifier().uid)
-        modelPath = temp_path + "/ovrModel"
-        model.save(modelPath)
-        loadedModel = OneVsRestModel.load(modelPath)
-        for m, n in zip(model.models, loadedModel.models):
-            self.assertEqual(m.uid, n.uid)
-
 
 class HashingTFTest(SparkSessionTestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to