Repository: spark
Updated Branches:
  refs/heads/master a3550e374 -> 130b8d07b
[SPARK-15008][ML][PYSPARK] Add integration test for OneVsRest

## What changes were proposed in this pull request?

1. Add `_transfer_param_map_to/from_java` for OneVsRest;
2. Add `_compare_params` in ml/tests.py to help compare params.
3. Add `test_onevsrest` as the integration test for OneVsRest.

## How was this patch tested?

Python unit test.

Author: yinxusen <yinxu...@gmail.com>

Closes #12875 from yinxusen/SPARK-15008.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/130b8d07
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/130b8d07
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/130b8d07

Branch: refs/heads/master
Commit: 130b8d07b8eb08f2ad522081a95032b90247094d
Parents: a3550e3
Author: yinxusen <yinxu...@gmail.com>
Authored: Fri May 27 13:18:29 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Fri May 27 13:18:29 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/tests.py | 69 +++++++++++++++++++++++++++--------------
 1 file changed, 46 insertions(+), 23 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/130b8d07/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index a7c93ac..4358175 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -747,12 +747,32 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass
 
+    def _compare_params(self, m1, m2, param):
+        """
+        Compare 2 ML Params instances for the given param, and assert both have the same param value
+        and parent. The param must be a parameter of m1.
+        """
+        # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap.
+        if m1.isDefined(param):
+            paramValue1 = m1.getOrDefault(param)
+            paramValue2 = m2.getOrDefault(m2.getParam(param.name))
+            if isinstance(paramValue1, Params):
+                self._compare_pipelines(paramValue1, paramValue2)
+            else:
+                self.assertEqual(paramValue1, paramValue2)  # for general types param
+            # Assert parents are equal
+            self.assertEqual(param.parent, m2.getParam(param.name).parent)
+        else:
+            # If m1 is not defined param, then m2 should not, too. See SPARK-14931.
+            self.assertFalse(m2.isDefined(m2.getParam(param.name)))
+
     def _compare_pipelines(self, m1, m2):
         """
         Compare 2 ML types, asserting that they are equivalent.
         This currently supports:
          - basic types
          - Pipeline, PipelineModel
+         - OneVsRest, OneVsRestModel
         This checks:
          - uid
          - type
@@ -763,8 +783,7 @@ class PersistenceTest(SparkSessionTestCase):
         if isinstance(m1, JavaParams):
             self.assertEqual(len(m1.params), len(m2.params))
             for p in m1.params:
-                self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
-                self.assertEqual(p.parent, m2.getParam(p.name).parent)
+                self._compare_params(m1, m2, p)
         elif isinstance(m1, Pipeline):
             self.assertEqual(len(m1.getStages()), len(m2.getStages()))
             for s1, s2 in zip(m1.getStages(), m2.getStages()):
@@ -773,6 +792,13 @@ class PersistenceTest(SparkSessionTestCase):
             self.assertEqual(len(m1.stages), len(m2.stages))
             for s1, s2 in zip(m1.stages, m2.stages):
                 self._compare_pipelines(s1, s2)
+        elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel):
+            for p in m1.params:
+                self._compare_params(m1, m2, p)
+            if isinstance(m1, OneVsRestModel):
+                self.assertEqual(len(m1.models), len(m2.models))
+                for x, y in zip(m1.models, m2.models):
+                    self._compare_pipelines(x, y)
         else:
             raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
 
@@ -833,6 +859,24 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass
 
+    def test_onevsrest(self):
+        temp_path = tempfile.mkdtemp()
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+                                         (1.0, Vectors.sparse(2, [], [])),
+                                         (2.0, Vectors.dense(0.5, 0.5))] * 10,
+                                        ["label", "features"])
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr)
+        model = ovr.fit(df)
+        ovrPath = temp_path + "/ovr"
+        ovr.save(ovrPath)
+        loadedOvr = OneVsRest.load(ovrPath)
+        self._compare_pipelines(ovr, loadedOvr)
+        modelPath = temp_path + "/ovrModel"
+        model.save(modelPath)
+        loadedModel = OneVsRestModel.load(modelPath)
+        self._compare_pipelines(model, loadedModel)
+
     def test_decisiontree_classifier(self):
         dt = DecisionTreeClassifier(maxDepth=1)
         path = tempfile.mkdtemp()
@@ -1054,27 +1098,6 @@ class OneVsRestTests(SparkSessionTestCase):
         output = model.transform(df)
         self.assertEqual(output.columns, ["label", "features", "prediction"])
 
-    def test_save_load(self):
-        temp_path = tempfile.mkdtemp()
-        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
-                                         (1.0, Vectors.sparse(2, [], [])),
-                                         (2.0, Vectors.dense(0.5, 0.5))],
-                                        ["label", "features"])
-        lr = LogisticRegression(maxIter=5, regParam=0.01)
-        ovr = OneVsRest(classifier=lr)
-        model = ovr.fit(df)
-        ovrPath = temp_path + "/ovr"
-        ovr.save(ovrPath)
-        loadedOvr = OneVsRest.load(ovrPath)
-        self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol())
-        self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol())
-        self.assertEqual(loadedOvr.getClassifier().uid, ovr.getClassifier().uid)
-        modelPath = temp_path + "/ovrModel"
-        model.save(modelPath)
-        loadedModel = OneVsRestModel.load(modelPath)
-        for m, n in zip(model.models, loadedModel.models):
-            self.assertEqual(m.uid, n.uid)
-
 
 class HashingTFTest(SparkSessionTestCase):
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org