Repository: spark
Updated Branches:
  refs/heads/master 781df4998 -> fc3cd2f50


[SPARK-14472][PYSPARK][ML] Cleanup ML JavaWrapper and related class hierarchy

Currently, JavaWrapper is only a wrapper class for pipeline classes that have
Params, and JavaCallable is a separate mixin that provides the methods to make
Java calls.  This change simplifies the class structure by defining the Java
wrapper in a plain base class along with the methods to make Java calls.  It
also renames the Java wrapper classes to better reflect their purpose.
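
For reference, a condensed sketch of the new hierarchy from
python/pyspark/ml/wrapper.py (not verbatim; docstrings and the Param-sync
helpers are omitted, and SparkContext, Params, _py2java and _java2py are
imported as in that module):

    class JavaWrapper(object):
        """Plain wrapper around a Java companion object; no Params needed."""

        def __init__(self, java_obj=None):
            super(JavaWrapper, self).__init__()
            self._java_obj = java_obj

        def _call_java(self, name, *args):
            # Convert Python args to Java, invoke the method on the wrapped
            # object, then convert the result back to Python.
            m = getattr(self._java_obj, name)
            sc = SparkContext._active_spark_context
            java_args = [_py2java(sc, arg) for arg in args]
            return _java2py(sc, m(*java_args))


    class JavaParams(JavaWrapper, Params):
        """Adds Params on top of the plain wrapper; JavaEstimator,
        JavaTransformer and JavaModel now extend this class."""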

Ran existing Python ml tests and generated documentation to test this change.

Author: Bryan Cutler <cutl...@gmail.com>

Closes #12304 from BryanCutler/pyspark-cleanup-JavaWrapper-SPARK-14472.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fc3cd2f5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fc3cd2f5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fc3cd2f5

Branch: refs/heads/master
Commit: fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4
Parents: 781df49
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Wed Apr 13 14:08:57 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Apr 13 14:08:57 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/classification.py |  4 +-
 python/pyspark/ml/evaluation.py     |  4 +-
 python/pyspark/ml/pipeline.py       | 10 ++---
 python/pyspark/ml/regression.py     |  4 +-
 python/pyspark/ml/tests.py          |  4 +-
 python/pyspark/ml/tuning.py         | 26 +++++------
 python/pyspark/ml/util.py           |  4 +-
 python/pyspark/ml/wrapper.py        | 76 ++++++++++++++------------------
 8 files changed, 62 insertions(+), 70 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fc3cd2f5/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index e64c7a3..922f806 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -19,7 +19,7 @@ import warnings
 
 from pyspark import since
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
 from pyspark.ml.param import TypeConverters
 from pyspark.ml.param.shared import *
 from pyspark.ml.regression import (
@@ -272,7 +272,7 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
         return BinaryLogisticRegressionSummary(java_blr_summary)
 
 
-class LogisticRegressionSummary(JavaCallable):
+class LogisticRegressionSummary(JavaWrapper):
     """
     Abstraction for Logistic Regression Results for a given model.
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fc3cd2f5/python/pyspark/ml/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index c9b95b3..4b0bade 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -18,7 +18,7 @@
 from abc import abstractmethod, ABCMeta
 
 from pyspark import since
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
 from pyspark.ml.util import keyword_only
@@ -81,7 +81,7 @@ class Evaluator(Params):
 
 
 @inherit_doc
-class JavaEvaluator(Evaluator, JavaWrapper):
+class JavaEvaluator(JavaParams, Evaluator):
     """
     Base class for :py:class:`Evaluator`s that wrap Java/Scala
     implementations.

http://git-wip-us.apache.org/repos/asf/spark/blob/fc3cd2f5/python/pyspark/ml/pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 2b5504b..9d654e8 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark import since
 from pyspark.ml import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.mllib.common import inherit_doc
 
 
@@ -177,7 +177,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         # Create a new instance of this stage.
         py_stage = cls()
         # Load information from java_stage to the instance.
-        py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+        py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()]
         py_stage.setStages(py_stages)
         py_stage._resetUid(java_stage.uid())
         return py_stage
@@ -195,7 +195,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         for idx, stage in enumerate(self.getStages()):
             java_stages[idx] = stage._to_java()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
         _java_obj.setStages(java_stages)
 
         return _java_obj
@@ -275,7 +275,7 @@ class PipelineModel(Model, MLReadable, MLWritable):
         Used for ML persistence.
         """
         # Load information from java_stage to the instance.
-        py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+        py_stages = [JavaParams._from_java(s) for s in java_stage.stages()]
         # Create a new instance of this stage.
         py_stage = cls(py_stages)
         py_stage._resetUid(java_stage.uid())
@@ -295,6 +295,6 @@ class PipelineModel(Model, MLReadable, MLWritable):
             java_stages[idx] = stage._to_java()
 
         _java_obj =\
-            JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+            JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
 
         return _java_obj

http://git-wip-us.apache.org/repos/asf/spark/blob/fc3cd2f5/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index bc88f88..316d7e3 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -20,7 +20,7 @@ import warnings
 from pyspark import since
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
 from pyspark.mllib.common import inherit_doc
 from pyspark.sql import DataFrame
 
@@ -188,7 +188,7 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
         return LinearRegressionSummary(java_lr_summary)
 
 
-class LinearRegressionSummary(JavaCallable):
+class LinearRegressionSummary(JavaWrapper):
     """
     .. note:: Experimental
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fc3cd2f5/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2dcd5ee..bcbeacb 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -52,7 +52,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
 from pyspark.ml.tuning import *
 from pyspark.ml.util import keyword_only
 from pyspark.ml.util import MLWritable, MLWriter
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
@@ -644,7 +644,7 @@ class PersistenceTest(PySparkTestCase):
         """
         self.assertEqual(m1.uid, m2.uid)
         self.assertEqual(type(m1), type(m2))
-        if isinstance(m1, JavaWrapper):
+        if isinstance(m1, JavaParams):
             self.assertEqual(len(m1.params), len(m2.params))
             for p in m1.params:
                 self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))

http://git-wip-us.apache.org/repos/asf/spark/blob/fc3cd2f5/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ea8c61b..456d79d 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
 from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
 from pyspark.mllib.common import inherit_doc, _py2java
 
@@ -148,8 +148,8 @@ class ValidatorParams(HasSeed):
         """
 
         # Load information from java_stage to the instance.
-        estimator = JavaWrapper._from_java(java_stage.getEstimator())
-        evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
+        estimator = JavaParams._from_java(java_stage.getEstimator())
+        evaluator = JavaParams._from_java(java_stage.getEvaluator())
         epms = [estimator._transfer_param_map_from_java(epm)
                 for epm in java_stage.getEstimatorParamMaps()]
         return estimator, epms, evaluator
@@ -329,7 +329,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
 
         estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
         _java_obj.setEstimatorParamMaps(epms)
         _java_obj.setEvaluator(evaluator)
         _java_obj.setEstimator(estimator)
@@ -393,7 +393,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
         """
 
         # Load information from java_stage to the instance.
-        bestModel = JavaWrapper._from_java(java_stage.bestModel())
+        bestModel = JavaParams._from_java(java_stage.bestModel())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
         py_stage = cls(bestModel=bestModel)\
@@ -410,10 +410,10 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
 
         sc = SparkContext._active_spark_context
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
-                                              self.uid,
-                                              self.bestModel._to_java(),
-                                              _py2java(sc, []))
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+                                             self.uid,
+                                             self.bestModel._to_java(),
+                                             _py2java(sc, []))
         estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
@@ -574,8 +574,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
 
         estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
-                                              self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+                                             self.uid)
         _java_obj.setEstimatorParamMaps(epms)
         _java_obj.setEvaluator(evaluator)
         _java_obj.setEstimator(estimator)
@@ -639,7 +639,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
         """
 
         # Load information from java_stage to the instance.
-        bestModel = JavaWrapper._from_java(java_stage.bestModel())
+        bestModel = JavaParams._from_java(java_stage.bestModel())
         estimator, epms, evaluator = \
             super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
@@ -657,7 +657,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
 
         sc = SparkContext._active_spark_context
 
-        _java_obj = JavaWrapper._new_java_obj(
+        _java_obj = JavaParams._new_java_obj(
             "org.apache.spark.ml.tuning.TrainValidationSplitModel",
             self.uid,
             self.bestModel._to_java(),

http://git-wip-us.apache.org/repos/asf/spark/blob/fc3cd2f5/python/pyspark/ml/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index d4411fd..9dfcef0 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -99,7 +99,7 @@ class MLWriter(object):
 @inherit_doc
 class JavaMLWriter(MLWriter):
     """
-    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
     """
 
     def __init__(self, instance):
@@ -178,7 +178,7 @@ class MLReader(object):
 @inherit_doc
 class JavaMLReader(MLReader):
     """
-    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
     """
 
     def __init__(self, clazz):

http://git-wip-us.apache.org/repos/asf/spark/blob/fc3cd2f5/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index bbeb6cf..cd0e5b8 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -25,29 +25,32 @@ from pyspark.ml.util import _jvm
 from pyspark.mllib.common import inherit_doc, _java2py, _py2java
 
 
-@inherit_doc
-class JavaWrapper(Params):
+class JavaWrapper(object):
     """
-    Utility class to help create wrapper classes from Java/Scala
-    implementations of pipeline components.
+    Wrapper class for a Java companion object
     """
+    def __init__(self, java_obj=None):
+        super(JavaWrapper, self).__init__()
+        self._java_obj = java_obj
 
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
+    @classmethod
+    def _create_from_java_class(cls, java_class, *args):
         """
-        Initialize the wrapped java object to None
+        Construct this object from given Java classname and arguments
         """
-        super(JavaWrapper, self).__init__()
-        #: The wrapped Java companion object. Subclasses should initialize
-        #: it properly. The param values in the Java object should be
-        #: synced with the Python wrapper in fit/transform/evaluate/copy.
-        self._java_obj = None
+        java_obj = JavaWrapper._new_java_obj(java_class, *args)
+        return cls(java_obj)
+
+    def _call_java(self, name, *args):
+        m = getattr(self._java_obj, name)
+        sc = SparkContext._active_spark_context
+        java_args = [_py2java(sc, arg) for arg in args]
+        return _java2py(sc, m(*java_args))
 
     @staticmethod
     def _new_java_obj(java_class, *args):
         """
-        Construct a new Java object.
+        Returns a new Java object.
         """
         sc = SparkContext._active_spark_context
         java_obj = _jvm()
@@ -56,6 +59,18 @@ class JavaWrapper(Params):
         java_args = [_py2java(sc, arg) for arg in args]
         return java_obj(*java_args)
 
+
+@inherit_doc
+class JavaParams(JavaWrapper, Params):
+    """
+    Utility class to help create wrapper classes from Java/Scala
+    implementations of pipeline components.
+    """
+    #: The param values in the Java object should be
+    #: synced with the Python wrapper in fit/transform/evaluate/copy.
+
+    __metaclass__ = ABCMeta
+
     def _make_java_param_pair(self, param, value):
         """
         Makes a Java parm pair.
@@ -151,7 +166,7 @@ class JavaWrapper(Params):
         stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
         # Generate a default new instance from the stage_name class.
         py_type = __get_class(stage_name)
-        if issubclass(py_type, JavaWrapper):
+        if issubclass(py_type, JavaParams):
             # Load information from java_stage to the instance.
             py_stage = py_type()
             py_stage._java_obj = java_stage
@@ -166,7 +181,7 @@ class JavaWrapper(Params):
 
 
 @inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
+class JavaEstimator(JavaParams, Estimator):
     """
     Base class for :py:class:`Estimator`s that wrap Java/Scala
     implementations.
@@ -199,7 +214,7 @@ class JavaEstimator(Estimator, JavaWrapper):
 
 
 @inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
+class JavaTransformer(JavaParams, Transformer):
     """
     Base class for :py:class:`Transformer`s that wrap Java/Scala
     implementations. Subclasses should ensure they have the transformer Java object
@@ -213,30 +228,8 @@ class JavaTransformer(Transformer, JavaWrapper):
         return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
 
 
-class JavaCallable(object):
-    """
-    Wrapper for a plain object in JVM to make Java calls, can be used
-    as a mixin to another class that defines a _java_obj wrapper
-    """
-    def __init__(self, java_obj=None, sc=None):
-        super(JavaCallable, self).__init__()
-        self._sc = sc if sc is not None else SparkContext._active_spark_context
-        # if this class is a mixin and _java_obj is already defined then don't initialize
-        if java_obj is not None or not hasattr(self, "_java_obj"):
-            self._java_obj = java_obj
-
-    def __del__(self):
-        if self._java_obj is not None:
-            self._sc._gateway.detach(self._java_obj)
-
-    def _call_java(self, name, *args):
-        m = getattr(self._java_obj, name)
-        java_args = [_py2java(self._sc, arg) for arg in args]
-        return _java2py(self._sc, m(*java_args))
-
-
 @inherit_doc
-class JavaModel(Model, JavaCallable, JavaTransformer):
+class JavaModel(JavaTransformer, Model):
     """
     Base class for :py:class:`Model`s that wrap Java/Scala
     implementations. Subclasses should inherit this class before
@@ -259,9 +252,8 @@ class JavaModel(Model, JavaCallable, JavaTransformer):
         these wrappers depend on pyspark.ml.util (both directly and via
         other ML classes).
         """
-        super(JavaModel, self).__init__()
+        super(JavaModel, self).__init__(java_model)
         if java_model is not None:
-            self._java_obj = java_model
             self.uid = java_model.uid()
 
     def copy(self, extra=None):
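
To illustrate the pattern this enables (a hypothetical sketch; the class and
property names below are illustrative only, not part of this commit): a
summary-style class can now extend the plain JavaWrapper directly and forward
its accessors through _call_java, as LogisticRegressionSummary and
LinearRegressionSummary do above.

    from pyspark.ml.wrapper import JavaWrapper

    class MyModelSummary(JavaWrapper):
        """Wraps a Java summary object; no Params are involved."""

        @property
        def totalIterations(self):
            # Delegates to the wrapped Java object's totalIterations() method.
            return self._call_java("totalIterations")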

