Github user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18742#discussion_r130521214
  
    --- Diff: python/pyspark/ml/util.py ---
    @@ -283,3 +289,124 @@ def numFeatures(self):
             Returns the number of features the model was trained on. If 
unknown, returns -1
             """
             return self._call_java("numFeatures")
    +
    +
    +@inherit_doc
    +class DefaultParamsWritable(MLWritable):
    +
    +    # overrides the write() function in MLWriteable
    +    # users call .save() in MLWriteable which calls this write() function 
and then calls
    +    # the .save() in DefaultParamsWriter
    +    # this can be overridden to return a different Writer (ex. 
OneVsRestWriter as seen in Scala)
    +    def write(self):
    +        # instance of check for params?
    +        return DefaultParamsWriter(self)
    +
    +
    +@inherit_doc
    +class DefaultParamsWriter(MLWriter):
    +
    +    def __init__(self, instance):
    +        super(DefaultParamsWriter, self).__init__()
    +        self.instance = instance
    +        self.sc = SparkContext._active_spark_context
    +
    +    # if a model extends DefaultParamsWriteable this save() function is 
called
    +    def save(self, path):
    +        if self.shouldOverwrite:
    +            # This command removes a file. Is this enough?
    +            os.remove(path)
    +        DefaultParamsWriter.save_metadata(self.instance, path, self.sc)
    +
    +    def overwrite(self):
    +        self.shouldOverwrite = True
    +        return self
    +
    +    @staticmethod
    +    def save_metadata(instance, path, sc, extraMetadata=None, 
paramMap=None):
    +        metadataPath = os.path.join(path, "metadata")
    +        metadataJson = DefaultParamsWriter.get_metadata_to_save(instance,
    +                                                                
metadataPath,
    +                                                                sc,
    +                                                                
extraMetadata,
    +                                                                paramMap)
    +        sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)
    +
    +    @staticmethod
    +    def get_metadata_to_save(instance, path, sc, extraMetadata=None, 
paramMap=None):
    +        uid = instance.uid
    +        cls = instance.__module__ + '.' + instance.__class__.__name__
    +        params = instance.extractParamMap()
    +        jsonParams = {}
    +        if paramMap is not None:
    +            for p in paramMap:
    +                jsonParams[p.name] = paramMap[p]
    +        else:
    +            for p in params:
    +                jsonParams[p.name] = params[p]
    +        basicMetadata = {"class": cls, "timestamp": int(round(time.time() 
* 1000)),
    +                         "sparkVersion": sc.version, "uid": uid, 
"paramMap": jsonParams}
    --- End diff --
    
    Maybe we should use `long(round(time.time() * 1000))`?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to