This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch fix_pipeline_tuning
in repository https://gitbox.apache.org/repos/asf/spark.git
commit c834fe8f335dc74db6346d82b5ce4cf742cba9bb
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Mon Apr 20 17:04:12 2020 +0800

    init
---
 python/pyspark/ml/pipeline.py | 46 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 44 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 09e0748..0004b64 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,8 +25,8 @@ from pyspark import since, keyword_only, SparkContext
 from pyspark.ml.base import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams
-from pyspark.ml.common import inherit_doc
+from pyspark.ml.wrapper import JavaParams, JavaWrapper
+from pyspark.ml.common import inherit_doc, _java2py, _py2java
 
 
 @inherit_doc
@@ -174,6 +174,48 @@ class Pipeline(Estimator, MLReadable, MLWritable):
 
         return _java_obj
 
+    def _make_java_param_pair(self, param, value):
+        """
+        Makes a Java param pair.
+        """
+        sc = SparkContext._active_spark_context
+        param = self._resolveParam(param)
+        java_param = sc._jvm.org.apache.spark.ml.param.Param(param.parent, param.name, param.doc)
+        if isinstance(value, JavaParams):
+            # used in the case of an estimator having another estimator as a parameter
+            # the reason why this is not in _py2java in common.py is that importing
+            # Estimator and Model in common.py results in a circular import with inherit_doc
+            java_value = value._to_java()
+        else:
+            java_value = _py2java(sc, value)
+        return java_param.w(java_value)
+
+    def _transfer_param_map_to_java(self, pyParamMap):
+        """
+        Transforms a Python ParamMap into a Java ParamMap.
+        """
+        paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+        for param in self.params:
+            if param in pyParamMap:
+                pair = self._make_java_param_pair(param, pyParamMap[param])
+                paramMap.put([pair])
+        return paramMap
+
+    def _transfer_param_map_from_java(self, javaParamMap):
+        """
+        Transforms a Java ParamMap into a Python ParamMap.
+        """
+        sc = SparkContext._active_spark_context
+        paramMap = dict()
+        for pair in javaParamMap.toList():
+            param = pair.param()
+            if self.hasParam(str(param.name())):
+                if param.name() == "classifier":
+                    paramMap[self.getParam(param.name())] = JavaParams._from_java(pair.value())
+                else:
+                    paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
+        return paramMap
+
 
 @inherit_doc
 class PipelineWriter(MLWriter):
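
For context on where these helpers get exercised, here is a minimal, illustrative sketch (not part of the commit). It assumes the scenario the fix_pipeline_tuning branch appears to target: a CrossValidator whose estimator is a Python Pipeline and whose param grid addresses params of nested stages. Persisting and reloading that tuner is one place where Python/Java ParamMap conversion happens. The dataset, column names, and save path below are invented for the example.

    # Illustrative sketch only, not part of the commit. Data, columns, and the
    # save path are made up for the example.
    import tempfile

    from pyspark.sql import SparkSession
    from pyspark.ml import Pipeline
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.feature import HashingTF, Tokenizer
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    spark = SparkSession.builder.getOrCreate()

    train = spark.createDataFrame(
        [(0, "a b c d spark", 1.0),
         (1, "b d", 0.0),
         (2, "spark f g h", 1.0),
         (3, "hadoop mapreduce", 0.0)],
        ["id", "text", "label"])

    tokenizer = Tokenizer(inputCol="text", outputCol="words")
    hashingTF = HashingTF(inputCol="words", outputCol="features")
    lr = LogisticRegression(maxIter=10)
    pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

    # The grid refers to params owned by individual stages; converting such a
    # map to a Java ParamMap is the kind of conversion the new
    # _make_java_param_pair / _transfer_param_map_to_java helpers handle for a
    # Python Pipeline.
    grid = (ParamGridBuilder()
            .addGrid(hashingTF.numFeatures, [10, 100])
            .addGrid(lr.regParam, [0.1, 0.01])
            .build())

    cv = CrossValidator(estimator=pipeline,
                        estimatorParamMaps=grid,
                        evaluator=BinaryClassificationEvaluator(),
                        numFolds=2)

    # Persisting and reloading the tuner round-trips the param maps through
    # the JVM, one place where these conversions are exercised.
    path = tempfile.mkdtemp() + "/cv"
    cv.write().save(path)
    cv_loaded = CrossValidator.load(path)

    model = cv_loaded.fit(train)
    print(model.avgMetrics)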