Repository: spark Updated Branches: refs/heads/branch-1.4 79983f17d -> f91bb57ef
[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml Otherwise, users can only use `transform` on the models. brkyvz Author: Xiangrui Meng <[email protected]> Closes #6156 from mengxr/SPARK-7647 and squashes the following commits: 1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel (cherry picked from commit 723853edab18d28515af22097b76e4e6574b228e) Signed-off-by: Xiangrui Meng <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f91bb57e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f91bb57e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f91bb57e Branch: refs/heads/branch-1.4 Commit: f91bb57efae94c7a9e18eb30ec66bed93815f62f Parents: 79983f1 Author: Xiangrui Meng <[email protected]> Authored: Thu May 14 18:13:58 2015 -0700 Committer: Xiangrui Meng <[email protected]> Committed: Thu May 14 18:14:07 2015 -0700 ---------------------------------------------------------------------- python/pyspark/ml/classification.py | 18 ++++++++++++++++++ python/pyspark/ml/regression.py | 18 ++++++++++++++++++ python/pyspark/ml/wrapper.py | 8 +++++++- 3 files changed, 43 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f91bb57e/python/pyspark/ml/classification.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 96d2905..8c9a55e 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> model.transform(test0).head().prediction 0.0 + >>> model.weights + DenseVector([5.5...]) + >>> model.intercept + -2.68... >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel): Model fitted by LogisticRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeClassifierParams(object): """ http://git-wip-us.apache.org/repos/asf/spark/blob/f91bb57e/python/pyspark/ml/regression.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0ab5c6c..2803864 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction -1.0 + >>> model.weights + DenseVector([1.0]) + >>> model.intercept + 0.0 >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 @@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel): Model fitted by LinearRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeRegressorParams(object): """ http://git-wip-us.apache.org/repos/asf/spark/blob/f91bb57e/python/pyspark/ml/wrapper.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index f5ac2a3..dda6c6a 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,7 +21,7 @@ from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model -from pyspark.mllib.common import inherit_doc +from pyspark.mllib.common import inherit_doc, _java2py, _py2java def _jvm(): @@ -149,6 +149,12 @@ class JavaModel(Model, JavaTransformer): def _java_obj(self): return self._java_model + def _call_java(self, name, *args): + m = getattr(self._java_model, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) + @inherit_doc class JavaEvaluator(Evaluator, JavaWrapper): --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
