Repository: spark
Updated Branches:
  refs/heads/master 6b94420f6 -> 8a13c5096


[SPARK-24058][ML][PYSPARK] Default Params in ML should be saved separately: Python API

## What changes were proposed in this pull request?

See SPARK-23455 for reference: on the Scala side, default params in ML are now saved in a separate
`defaultParamMap` section of the metadata file instead of being mixed in with user-supplied params.
We must make the corresponding change on the Python side for Spark 2.4.0 in order to keep the two
implementations in sync.
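
For illustration, the metadata JSON written by `DefaultParamsWriter._get_metadata_to_save` then has roughly the shape sketched below; the concrete values (uid, timestamp, param values) are made up, only the layout matters:

```python
# Rough shape of the per-instance metadata after this change (illustrative values only).
metadata = {
    "class": "pyspark.ml.classification.LogisticRegression",
    "timestamp": 1526428209000,                          # epoch millis
    "sparkVersion": "2.4.0",
    "uid": "LogisticRegression_4d2fb9ae8b1c",            # hypothetical uid
    "paramMap": {"maxIter": 50, "threshold": 0.75},      # params explicitly set by the user
    "defaultParamMap": {"predictionCol": "prediction"},  # new: default values saved separately
}
```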

## How was this patch tested?

Added a unit test covering the `defaultParamMap` round trip and backward compatibility with
pre-2.4.0 metadata (sketched below).
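
The added test also covers reading metadata written before Spark 2.4.0, which has no `defaultParamMap` section. The version gate it exercises can be sketched as follows; the helper name `expects_default_param_map` is made up for illustration, and only the `VersionUtils.majorMinorVersion` call comes from the patch:

```python
from pyspark.util import VersionUtils

def expects_default_param_map(spark_version):
    # Metadata written by Spark >= 2.4.0 must contain a `defaultParamMap`
    # section; older metadata is still accepted without one.
    major, minor = VersionUtils.majorMinorVersion(spark_version)
    return major > 2 or (major == 2 and minor >= 4)

assert expects_default_param_map("2.4.0")
assert not expects_default_param_map("2.3.0")
```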

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #21153 from viirya/SPARK-24058.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8a13c509
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8a13c509
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8a13c509

Branch: refs/heads/master
Commit: 8a13c5096898f95d1dfcedaf5d31205a1cbf0a19
Parents: 6b94420
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Tue May 15 16:50:09 2018 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue May 15 16:50:09 2018 -0700

----------------------------------------------------------------------
 python/pyspark/ml/tests.py | 38 ++++++++++++++++++++++++++++++++++++++
 python/pyspark/ml/util.py  | 30 ++++++++++++++++++++++++++++--
 2 files changed, 66 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8a13c509/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 0935931..0dde0db 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1595,6 +1595,44 @@ class PersistenceTest(SparkSessionTestCase):
         self.assertEqual(lr.uid, lr3.uid)
         self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
 
+    def test_default_read_write_default_params(self):
+        lr = LogisticRegression()
+        self.assertFalse(lr.isSet(lr.getParam("threshold")))
+
+        lr.setMaxIter(50)
+        lr.setThreshold(.75)
+
+        # `threshold` is set by user, default param `predictionCol` is not set by user.
+        self.assertTrue(lr.isSet(lr.getParam("threshold")))
+        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+        writer = DefaultParamsWriter(lr)
+        metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
+        self.assertTrue("defaultParamMap" in metadata)
+
+        reader = DefaultParamsReadable.read()
+        metadataStr = json.dumps(metadata, separators=[',',  ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr, )
+        reader.getAndSetParams(lr, loadedMetadata)
+
+        self.assertTrue(lr.isSet(lr.getParam("threshold")))
+        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+        # manually create metadata without `defaultParamMap` section.
+        del metadata['defaultParamMap']
+        metadataStr = json.dumps(metadata, separators=[',',  ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr, )
+        with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
+            reader.getAndSetParams(lr, loadedMetadata)
+
+        # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
+        metadata['sparkVersion'] = '2.3.0'
+        metadataStr = json.dumps(metadata, separators=[',',  ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr, )
+        reader.getAndSetParams(lr, loadedMetadata)
+
 
 class LDATest(SparkSessionTestCase):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8a13c509/python/pyspark/ml/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index a486c6a..9fa8566 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -30,6 +30,7 @@ if sys.version > '3':
 from pyspark import SparkContext, since
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
+from pyspark.util import VersionUtils
 
 
 def _jvm():
@@ -396,6 +397,7 @@ class DefaultParamsWriter(MLWriter):
         - sparkVersion
         - uid
         - paramMap
+        - defaultParamMap (since 2.4.0)
         - (optionally, extra metadata)
         :param extraMetadata:  Extra metadata to be saved at same level as uid, paramMap, etc.
         :param paramMap:  If given, this is saved in the "paramMap" field.
@@ -417,15 +419,24 @@ class DefaultParamsWriter(MLWriter):
         """
         uid = instance.uid
         cls = instance.__module__ + '.' + instance.__class__.__name__
-        params = instance.extractParamMap()
+
+        # User-supplied param values
+        params = instance._paramMap
         jsonParams = {}
         if paramMap is not None:
             jsonParams = paramMap
         else:
             for p in params:
                 jsonParams[p.name] = params[p]
+
+        # Default param values
+        jsonDefaultParams = {}
+        for p in instance._defaultParamMap:
+            jsonDefaultParams[p.name] = instance._defaultParamMap[p]
+
         basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
-                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams}
+                         "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
+                         "defaultParamMap": jsonDefaultParams}
         if extraMetadata is not None:
             basicMetadata.update(extraMetadata)
         return json.dumps(basicMetadata, separators=[',',  ':'])
@@ -523,11 +534,26 @@ class DefaultParamsReader(MLReader):
         """
         Extract Params from metadata, and set them in the instance.
         """
+        # Set user-supplied param values
         for paramName in metadata['paramMap']:
             param = instance.getParam(paramName)
             paramValue = metadata['paramMap'][paramName]
             instance.set(param, paramValue)
 
+        # Set default param values
+        majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
+        major = majorAndMinorVersions[0]
+        minor = majorAndMinorVersions[1]
+
+        # For metadata file prior to Spark 2.4, there is no default section.
+        if major > 2 or (major == 2 and minor >= 4):
+            assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \
+                "`defaultParamMap` section not found"
+
+            for paramName in metadata['defaultParamMap']:
+                paramValue = metadata['defaultParamMap'][paramName]
+                instance._setDefault(**{paramName: paramValue})
+
     @staticmethod
     def loadParamsInstance(path, sc):
         """

