Repository: spark Updated Branches: refs/heads/branch-2.1 97866e198 -> 20a432951
[SPARK-14772][PYTHON][ML] Fixed Params.copy method to match Scala implementation ## What changes were proposed in this pull request? Fixed the PySpark Params.copy method to behave like the Scala implementation. The main issue was that it did not keep the _defaultParamMap separate: default param values were merged into the explicitly set param map of the copy. ## How was this patch tested? Added a new unit test to verify that the copy method behaves correctly for copying the uid, explicitly set params, and default params. Author: Bryan Cutler <cutl...@gmail.com> Closes #17048 from BryanCutler/pyspark-ml-param_copy-Scala_sync-SPARK-14772-2_1. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/20a43295 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/20a43295 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/20a43295 Branch: refs/heads/branch-2.1 Commit: 20a432951c6281bb6d6bf9252ad5a352fef00424 Parents: 97866e1 Author: Bryan Cutler <cutl...@gmail.com> Authored: Sat Feb 25 20:03:27 2017 -0800 Committer: Joseph K. 
Bradley <jos...@databricks.com> Committed: Sat Feb 25 20:03:27 2017 -0800 ---------------------------------------------------------------------- python/pyspark/ml/param/__init__.py | 17 +++++++++++------ python/pyspark/ml/tests.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/20a43295/python/pyspark/ml/param/__init__.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index ade4864..205b8d5 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -385,6 +385,7 @@ class Params(Identifiable): extra = dict() that = copy.copy(self) that._paramMap = {} + that._defaultParamMap = {} return self._copyValues(that, extra) def _shouldOwn(self, param): @@ -465,12 +466,16 @@ class Params(Identifiable): :param extra: extra params to be copied :return: the target instance with param values copied """ - if extra is None: - extra = dict() - paramMap = self.extractParamMap(extra) - for p in self.params: - if p in paramMap and to.hasParam(p.name): - to._set(**{p.name: paramMap[p]}) + paramMap = self._paramMap.copy() + if extra is not None: + paramMap.update(extra) + for param in self.params: + # copy default params + if param in self._defaultParamMap and to.hasParam(param.name): + to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param] + # copy explicitly set params + if param in paramMap and to.hasParam(param.name): + to._set(**{param.name: paramMap[param]}) return to def _resetUid(self, newUid): http://git-wip-us.apache.org/repos/asf/spark/blob/20a43295/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 68f5bc3..46be031 100755 --- 
a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -389,6 +389,22 @@ class ParamTests(PySparkTestCase): # Check windowSize is set properly self.assertEqual(model.getWindowSize(), 6) + def test_copy_param_extras(self): + tp = TestParams(seed=42) + extra = {tp.getParam(TestParams.inputCol.name): "copy_input"} + tp_copy = tp.copy(extra=extra) + self.assertEqual(tp.uid, tp_copy.uid) + self.assertEqual(tp.params, tp_copy.params) + for k, v in extra.items(): + self.assertTrue(tp_copy.isDefined(k)) + self.assertEqual(tp_copy.getOrDefault(k), v) + copied_no_extra = {} + for k, v in tp_copy._paramMap.items(): + if k not in extra: + copied_no_extra[k] = v + self.assertEqual(tp._paramMap, copied_no_extra) + self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap) + class EvaluatorTests(SparkSessionTestCase): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org