Repository: spark
Updated Branches:
  refs/heads/master 8b56f1664 -> cb43bbe13


[SPARK-21685][PYTHON][ML] PySpark Params isSet state should not change after transform

## What changes were proposed in this pull request?

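Currently, when a PySpark Model is transformed, default params that have not 
been explicitly set are set on the Java side by the call to 
`JavaParams._transfer_params_to_java` (in `wrapper.py`). This incorrectly 
changes the state of those Params: they should still be marked as having only 
a default value, not as explicitly set. With this change, explicitly set 
params are transferred with `set`, while PySpark default values are 
transferred with `setDefault`, so `isSet` remains false for params that were 
never set.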

## How was this patch tested?

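Added new tests to verify that, when Params are transferred to Java, default 
params keep their unset state (`test_preserve_set_state`) and PySpark default 
values are still passed to the Java object (`test_default_params_transferred`).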

Author: Bryan Cutler <cutl...@gmail.com>

Closes #18982 from BryanCutler/pyspark-ml-param-to-java-defaults-SPARK-21685.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cb43bbe1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cb43bbe1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cb43bbe1

Branch: refs/heads/master
Commit: cb43bbe13606673349511829fd71d1f34fc39c45
Parents: 8b56f16
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Fri Mar 23 11:42:40 2018 -0700
Committer: Holden Karau <hol...@pigscanfly.ca>
Committed: Fri Mar 23 11:42:40 2018 -0700

----------------------------------------------------------------------
 python/pyspark/ml/tests.py   | 20 +++++++++++++++++++-
 python/pyspark/ml/wrapper.py | 13 ++++++++++---
 2 files changed, 29 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cb43bbe1/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index fd45fd0..0801199 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -369,7 +369,7 @@ class HasThrowableProperty(Params):
         raise RuntimeError("Test property to raise error when invoked")
 
 
-class ParamTests(PySparkTestCase):
+class ParamTests(SparkSessionTestCase):
 
     def test_copy_new_parent(self):
         testParams = TestParams()
@@ -514,6 +514,24 @@ class ParamTests(PySparkTestCase):
             LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
         )
 
+    def test_preserve_set_state(self):
+        dataset = self.spark.createDataFrame([(0.5,)], ["data"])
+        binarizer = Binarizer(inputCol="data")
+        self.assertFalse(binarizer.isSet("threshold"))
+        binarizer.transform(dataset)
+        binarizer._transfer_params_from_java()
+        self.assertFalse(binarizer.isSet("threshold"),
+                         "Params not explicitly set should remain unset after 
transform")
+
+    def test_default_params_transferred(self):
+        dataset = self.spark.createDataFrame([(0.5,)], ["data"])
+        binarizer = Binarizer(inputCol="data")
+        # intentionally change the pyspark default, but don't set it
+        binarizer._defaultParamMap[binarizer.outputCol] = "my_default"
+        result = binarizer.transform(dataset).select("my_default").collect()
+        self.assertFalse(binarizer.isSet(binarizer.outputCol))
+        self.assertEqual(result[0][0], 1.0)
+
     @staticmethod
     def check_params(test_self, py_stage, check_params_exist=True):
         """

http://git-wip-us.apache.org/repos/asf/spark/blob/cb43bbe1/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 5061f64..d325633 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -118,11 +118,18 @@ class JavaParams(JavaWrapper, Params):
         """
         Transforms the embedded params to the companion Java object.
         """
-        paramMap = self.extractParamMap()
+        pair_defaults = []
         for param in self.params:
-            if param in paramMap:
-                pair = self._make_java_param_pair(param, paramMap[param])
+            if self.isSet(param):
+                pair = self._make_java_param_pair(param, self._paramMap[param])
                 self._java_obj.set(pair)
+            if self.hasDefault(param):
+                pair = self._make_java_param_pair(param, 
self._defaultParamMap[param])
+                pair_defaults.append(pair)
+        if len(pair_defaults) > 0:
+            sc = SparkContext._active_spark_context
+            pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults)
+            self._java_obj.setDefault(pair_defaults_seq)
 
     def _transfer_param_map_to_java(self, pyParamMap):
         """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to