Repository: spark
Updated Branches:
  refs/heads/branch-1.4 79983f17d -> f91bb57ef


[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml

Without these properties, users can only call `transform` on the fitted models. cc brkyvz

Author: Xiangrui Meng <[email protected]>

Closes #6156 from mengxr/SPARK-7647 and squashes the following commits:

1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python
f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel

(cherry picked from commit 723853edab18d28515af22097b76e4e6574b228e)
Signed-off-by: Xiangrui Meng <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f91bb57e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f91bb57e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f91bb57e

Branch: refs/heads/branch-1.4
Commit: f91bb57efae94c7a9e18eb30ec66bed93815f62f
Parents: 79983f1
Author: Xiangrui Meng <[email protected]>
Authored: Thu May 14 18:13:58 2015 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Thu May 14 18:14:07 2015 -0700

----------------------------------------------------------------------
 python/pyspark/ml/classification.py | 18 ++++++++++++++++++
 python/pyspark/ml/regression.py     | 18 ++++++++++++++++++
 python/pyspark/ml/wrapper.py        |  8 +++++++-
 3 files changed, 43 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f91bb57e/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 96d2905..8c9a55e 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
     >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
     >>> model.transform(test0).head().prediction
     0.0
+    >>> model.weights
+    DenseVector([5.5...])
+    >>> model.intercept
+    -2.68...
     >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
     >>> model.transform(test1).head().prediction
     1.0
@@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel):
     Model fitted by LogisticRegression.
     """
 
+    @property
+    def weights(self):
+        """
+        Model weights.
+        """
+        return self._call_java("weights")
+
+    @property
+    def intercept(self):
+        """
+        Model intercept.
+        """
+        return self._call_java("intercept")
+
 
 class TreeClassifierParams(object):
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/f91bb57e/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 0ab5c6c..2803864 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     -1.0
+    >>> model.weights
+    DenseVector([1.0])
+    >>> model.intercept
+    0.0
     >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
@@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel):
     Model fitted by LinearRegression.
     """
 
+    @property
+    def weights(self):
+        """
+        Model weights.
+        """
+        return self._call_java("weights")
+
+    @property
+    def intercept(self):
+        """
+        Model intercept.
+        """
+        return self._call_java("intercept")
+
 
 class TreeRegressorParams(object):
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/f91bb57e/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index f5ac2a3..dda6c6a 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -21,7 +21,7 @@ from pyspark import SparkContext
 from pyspark.sql import DataFrame
 from pyspark.ml.param import Params
 from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
-from pyspark.mllib.common import inherit_doc
+from pyspark.mllib.common import inherit_doc, _java2py, _py2java
 
 
 def _jvm():
@@ -149,6 +149,12 @@ class JavaModel(Model, JavaTransformer):
     def _java_obj(self):
         return self._java_model
 
+    def _call_java(self, name, *args):
+        m = getattr(self._java_model, name)
+        sc = SparkContext._active_spark_context
+        java_args = [_py2java(sc, arg) for arg in args]
+        return _java2py(sc, m(*java_args))
+
 
 @inherit_doc
 class JavaEvaluator(Evaluator, JavaWrapper):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to