Repository: spark Updated Branches: refs/heads/master 8b56f1664 -> cb43bbe13
[SPARK-21685][PYTHON][ML] PySpark Params isSet state should not change after transform ## What changes were proposed in this pull request? Currently when a PySpark Model is transformed, default params that have not been explicitly set are then set on the Java side on the call to `wrapper._transfer_params_to_java`. This incorrectly changes the state of the Param as it should still be marked as a default value only. ## How was this patch tested? Added a new test to verify that when transferring Params to Java, default params have their state preserved. Author: Bryan Cutler <cutl...@gmail.com> Closes #18982 from BryanCutler/pyspark-ml-param-to-java-defaults-SPARK-21685. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cb43bbe1 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cb43bbe1 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cb43bbe1 Branch: refs/heads/master Commit: cb43bbe13606673349511829fd71d1f34fc39c45 Parents: 8b56f16 Author: Bryan Cutler <cutl...@gmail.com> Authored: Fri Mar 23 11:42:40 2018 -0700 Committer: Holden Karau <hol...@pigscanfly.ca> Committed: Fri Mar 23 11:42:40 2018 -0700 ---------------------------------------------------------------------- python/pyspark/ml/tests.py | 20 +++++++++++++++++++- python/pyspark/ml/wrapper.py | 13 ++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/cb43bbe1/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index fd45fd0..0801199 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -369,7 +369,7 @@ class HasThrowableProperty(Params): raise RuntimeError("Test property to raise error when invoked") -class ParamTests(PySparkTestCase): +class 
ParamTests(SparkSessionTestCase): def test_copy_new_parent(self): testParams = TestParams() @@ -514,6 +514,24 @@ class ParamTests(PySparkTestCase): LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] ) + def test_preserve_set_state(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + self.assertFalse(binarizer.isSet("threshold")) + binarizer.transform(dataset) + binarizer._transfer_params_from_java() + self.assertFalse(binarizer.isSet("threshold"), + "Params not explicitly set should remain unset after transform") + + def test_default_params_transferred(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + # intentionally change the pyspark default, but don't set it + binarizer._defaultParamMap[binarizer.outputCol] = "my_default" + result = binarizer.transform(dataset).select("my_default").collect() + self.assertFalse(binarizer.isSet(binarizer.outputCol)) + self.assertEqual(result[0][0], 1.0) + @staticmethod def check_params(test_self, py_stage, check_params_exist=True): """ http://git-wip-us.apache.org/repos/asf/spark/blob/cb43bbe1/python/pyspark/ml/wrapper.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 5061f64..d325633 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -118,11 +118,18 @@ class JavaParams(JavaWrapper, Params): """ Transforms the embedded params to the companion Java object. 
""" - paramMap = self.extractParamMap() + pair_defaults = [] for param in self.params: - if param in paramMap: - pair = self._make_java_param_pair(param, paramMap[param]) + if self.isSet(param): + pair = self._make_java_param_pair(param, self._paramMap[param]) self._java_obj.set(pair) + if self.hasDefault(param): + pair = self._make_java_param_pair(param, self._defaultParamMap[param]) + pair_defaults.append(pair) + if len(pair_defaults) > 0: + sc = SparkContext._active_spark_context + pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults) + self._java_obj.setDefault(pair_defaults_seq) def _transfer_param_map_to_java(self, pyParamMap): """ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org