Repository: spark
Updated Branches:
  refs/heads/master 1fa58868b -> 81303f7ca


[SPARK-19806][ML][PYSPARK] PySpark GeneralizedLinearRegression supports tweedie distribution.

## What changes were proposed in this pull request?
PySpark `GeneralizedLinearRegression` now supports the Tweedie distribution family, exposing the `variancePower` and `linkPower` params.

## How was this patch tested?
Added unit tests (`GeneralizedLinearRegressionTest.test_tweedie_distribution` in `python/pyspark/ml/tests.py`).

Author: Yanbo Liang <yblia...@gmail.com>

Closes #17146 from yanboliang/spark-19806.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/81303f7c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/81303f7c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/81303f7c

Branch: refs/heads/master
Commit: 81303f7ca7808d51229411dce8feeed8c23dbe15
Parents: 1fa5886
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Mar 8 02:09:36 2017 -0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Wed Mar 8 02:09:36 2017 -0800

----------------------------------------------------------------------
 .../GeneralizedLinearRegression.scala           |  8 +--
 python/pyspark/ml/regression.py                 | 61 +++++++++++++++++---
 python/pyspark/ml/tests.py                      | 20 +++++++
 3 files changed, 77 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/81303f7c/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 110764d..3be8b53 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -66,7 +66,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
   /**
    * Param for the power in the variance function of the Tweedie distribution which provides
    * the relationship between the variance and mean of the distribution.
-   * Only applicable for the Tweedie family.
+   * Only applicable to the Tweedie family.
    * (see <a href="https://en.wikipedia.org/wiki/Tweedie_distribution">
    * Tweedie Distribution (Wikipedia)</a>)
    * Supported values: 0 and [1, Inf).
@@ -79,7 +79,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
   final val variancePower: DoubleParam = new DoubleParam(this, "variancePower",
     "The power in the variance function of the Tweedie distribution which characterizes " +
     "the relationship between the variance and mean of the distribution. " +
-    "Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).",
+    "Only applicable to the Tweedie family. Supported values: 0 and [1, Inf).",
     (x: Double) => x >= 1.0 || x == 0.0)
 
   /** @group getParam */
@@ -106,7 +106,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
   def getLink: String = $(link)
 
   /**
-   * Param for the index in the power link function. Only applicable for the Tweedie family.
+   * Param for the index in the power link function. Only applicable to the Tweedie family.
    * Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt
    * link, respectively.
    * When not set, this value defaults to 1 - [[variancePower]], which matches the R "statmod"
@@ -116,7 +116,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
    */
   @Since("2.2.0")
   final val linkPower: DoubleParam = new DoubleParam(this, "linkPower",
-    "The index in the power link function. Only applicable for the Tweedie family.")
+    "The index in the power link function. Only applicable to the Tweedie family.")
 
   /** @group getParam */
   @Since("2.2.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/81303f7c/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index b199bf2..3c3fcc8 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1294,8 +1294,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
 
     Fit a Generalized Linear Model specified by giving a symbolic description of the linear
     predictor (link function) and a description of the error distribution (family). It supports
-    "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
-    is listed below. The first link function of each family is the default one.
+    "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for
+    each family is listed below. The first link function of each family is the default one.
 
     * "gaussian" -> "identity", "log", "inverse"
 
@@ -1305,6 +1305,9 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
 
     * "gamma"    -> "inverse", "identity", "log"
 
+    * "tweedie"  -> power link function specified through "linkPower". \
+                    The default link power in the tweedie family is 1 - variancePower.
+
     .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
 
     >>> from pyspark.ml.linalg import Vectors
@@ -1344,7 +1347,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
 
     family = Param(Params._dummy(), "family", "The name of family which is a description of " +
                    "the error distribution to be used in the model. Supported options: " +
-                   "gaussian (default), binomial, poisson and gamma.",
+                   "gaussian (default), binomial, poisson, gamma and tweedie.",
                    typeConverter=TypeConverters.toString)
     link = Param(Params._dummy(), "link", "The name of link function which provides the " +
                  "relationship between the linear predictor and the mean of the distribution " +
@@ -1352,32 +1355,46 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
                  "and sqrt.", typeConverter=TypeConverters.toString)
     linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " +
                               "predictor) column name", typeConverter=TypeConverters.toString)
+    variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " +
+                          "of the Tweedie distribution which characterizes the relationship " +
+                          "between the variance and mean of the distribution. Only applicable " +
+                          "for the Tweedie family. Supported values: 0 and [1, Inf).",
+                          typeConverter=TypeConverters.toFloat)
+    linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
+                      "Only applicable to the Tweedie family.",
+                      typeConverter=TypeConverters.toFloat)
 
     @keyword_only
     def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
-                 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None):
+                 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
+                 variancePower=0.0, linkPower=None):
         """
         __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
-                 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
+                 regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
+                 variancePower=0.0, linkPower=None)
         """
         super(GeneralizedLinearRegression, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
-        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
+        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
+                         variancePower=0.0)
         kwargs = self._input_kwargs
+
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.0.0")
     def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
                   family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
-                  regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None):
+                  regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
+                  variancePower=0.0, linkPower=None):
         """
         setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
                   family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
-                  regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
+                  regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
+                  variancePower=0.0, linkPower=None)
         Sets params for generalized linear regression.
         """
         kwargs = self._input_kwargs
@@ -1428,6 +1445,34 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
         """
         return self.getOrDefault(self.link)
 
+    @since("2.2.0")
+    def setVariancePower(self, value):
+        """
+        Sets the value of :py:attr:`variancePower`.
+        """
+        return self._set(variancePower=value)
+
+    @since("2.2.0")
+    def getVariancePower(self):
+        """
+        Gets the value of variancePower or its default value.
+        """
+        return self.getOrDefault(self.variancePower)
+
+    @since("2.2.0")
+    def setLinkPower(self, value):
+        """
+        Sets the value of :py:attr:`linkPower`.
+        """
+        return self._set(linkPower=value)
+
+    @since("2.2.0")
+    def getLinkPower(self):
+        """
+        Gets the value of linkPower or its default value.
+        """
+        return self.getOrDefault(self.linkPower)
+
 
 class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
                                        JavaMLReadable):

http://git-wip-us.apache.org/repos/asf/spark/blob/81303f7c/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 3524160..f052f5b 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1223,6 +1223,26 @@ class HashingTFTest(SparkSessionTestCase):
                                    ": expected " + str(expected[i]) + ", got " + str(features[i]))
 
 
+class GeneralizedLinearRegressionTest(SparkSessionTestCase):
+
+    def test_tweedie_distribution(self):
+
+        df = self.spark.createDataFrame(
+            [(1.0, Vectors.dense(0.0, 0.0)),
+             (1.0, Vectors.dense(1.0, 2.0)),
+             (2.0, Vectors.dense(0.0, 0.0)),
+             (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"])
+
+        glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6)
+        model = glr.fit(df)
+        self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4))
+        self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4))
+
+        model2 = glr.setLinkPower(-1.0).fit(df)
+        self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4))
+        self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))
+
+
 class ALSTest(SparkSessionTestCase):
 
     def test_storage_levels(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to