Repository: spark
Updated Branches:
  refs/heads/branch-2.1 97866e198 -> 20a432951


[SPARK-14772][PYTHON][ML] Fixed Params.copy method to match Scala implementation

## What changes were proposed in this pull request?
Fixed the PySpark Params.copy method to behave like the Scala implementation.
The main issue was that it did not keep _defaultParamMap separate; instead,
it merged the default params into the explicitly set param map.

## How was this patch tested?
Added a new unit test to verify that the copy method correctly copies the uid,
explicitly set params, and default params.

Author: Bryan Cutler <cutl...@gmail.com>

Closes #17048 from BryanCutler/pyspark-ml-param_copy-Scala_sync-SPARK-14772-2_1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/20a43295
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/20a43295
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/20a43295

Branch: refs/heads/branch-2.1
Commit: 20a432951c6281bb6d6bf9252ad5a352fef00424
Parents: 97866e1
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Sat Feb 25 20:03:27 2017 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Sat Feb 25 20:03:27 2017 -0800

----------------------------------------------------------------------
 python/pyspark/ml/param/__init__.py | 17 +++++++++++------
 python/pyspark/ml/tests.py          | 16 ++++++++++++++++
 2 files changed, 27 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/20a43295/python/pyspark/ml/param/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index ade4864..205b8d5 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -385,6 +385,7 @@ class Params(Identifiable):
             extra = dict()
         that = copy.copy(self)
         that._paramMap = {}
+        that._defaultParamMap = {}
         return self._copyValues(that, extra)
 
     def _shouldOwn(self, param):
@@ -465,12 +466,16 @@ class Params(Identifiable):
         :param extra: extra params to be copied
         :return: the target instance with param values copied
         """
-        if extra is None:
-            extra = dict()
-        paramMap = self.extractParamMap(extra)
-        for p in self.params:
-            if p in paramMap and to.hasParam(p.name):
-                to._set(**{p.name: paramMap[p]})
+        paramMap = self._paramMap.copy()
+        if extra is not None:
+            paramMap.update(extra)
+        for param in self.params:
+            # copy default params
+            if param in self._defaultParamMap and to.hasParam(param.name):
+                to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param]
+            # copy explicitly set params
+            if param in paramMap and to.hasParam(param.name):
+                to._set(**{param.name: paramMap[param]})
         return to
 
     def _resetUid(self, newUid):

http://git-wip-us.apache.org/repos/asf/spark/blob/20a43295/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 68f5bc3..46be031 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -389,6 +389,22 @@ class ParamTests(PySparkTestCase):
         # Check windowSize is set properly
         self.assertEqual(model.getWindowSize(), 6)
 
+    def test_copy_param_extras(self):
+        tp = TestParams(seed=42)
+        extra = {tp.getParam(TestParams.inputCol.name): "copy_input"}
+        tp_copy = tp.copy(extra=extra)
+        self.assertEqual(tp.uid, tp_copy.uid)
+        self.assertEqual(tp.params, tp_copy.params)
+        for k, v in extra.items():
+            self.assertTrue(tp_copy.isDefined(k))
+            self.assertEqual(tp_copy.getOrDefault(k), v)
+        copied_no_extra = {}
+        for k, v in tp_copy._paramMap.items():
+            if k not in extra:
+                copied_no_extra[k] = v
+        self.assertEqual(tp._paramMap, copied_no_extra)
+        self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap)
+
 
 class EvaluatorTests(SparkSessionTestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to