Repository: spark
Updated Branches:
  refs/heads/master f5065abf4 -> e71cd96bf
[SPARK-15316][PYSPARK][ML] Add linkPredictionCol to GeneralizedLinearRegression ## What changes were proposed in this pull request? Add linkPredictionCol to GeneralizedLinearRegression and fix the PyDoc to generate the bullet list ## How was this patch tested? doctests & built docs locally Author: Holden Karau <hol...@us.ibm.com> Closes #13106 from holdenk/SPARK-15316-add-linkPredictionCol-toGeneralizedLinearRegression. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e71cd96b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e71cd96b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e71cd96b Branch: refs/heads/master Commit: e71cd96bf733f0440f818c6efc7a04b68d7cbe45 Parents: f5065ab Author: Holden Karau <hol...@us.ibm.com> Authored: Thu May 19 20:59:19 2016 +0200 Committer: Nick Pentreath <ni...@za.ibm.com> Committed: Thu May 19 20:59:19 2016 +0200 ---------------------------------------------------------------------- python/pyspark/ml/regression.py | 46 +++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/e71cd96b/python/pyspark/ml/regression.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index cfcbbfc..25640b1 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1245,10 +1245,14 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha predictor (link function) and a description of the error distribution (family). It supports "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family is listed below. The first link function of each family is the default one. 
- - "gaussian" -> "identity", "log", "inverse" - - "binomial" -> "logit", "probit", "cloglog" - - "poisson" -> "log", "identity", "sqrt" - - "gamma" -> "inverse", "identity", "log" + + * "gaussian" -> "identity", "log", "inverse" + + * "binomial" -> "logit", "probit", "cloglog" + + * "poisson" -> "log", "identity", "sqrt" + + * "gamma" -> "inverse", "identity", "log" .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_ @@ -1258,9 +1262,12 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha ... (1.0, Vectors.dense(1.0, 2.0)), ... (2.0, Vectors.dense(0.0, 0.0)), ... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"]) - >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity") + >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p") >>> model = glr.fit(df) - >>> abs(model.transform(df).head().prediction - 1.5) < 0.001 + >>> transformed = model.transform(df) + >>> abs(transformed.head().prediction - 1.5) < 0.001 + True + >>> abs(transformed.head().p - 1.5) < 0.001 True >>> model.coefficients DenseVector([1.5..., -1.0...]) @@ -1290,20 +1297,23 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha "relationship between the linear predictor and the mean of the distribution " + "function. 
Supported options: identity, log, inverse, logit, probit, cloglog " + "and sqrt.", typeConverter=TypeConverters.toString) + linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " + + "predictor) column name", typeConverter=TypeConverters.toString) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls"): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=""): """ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls") + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="") """ super(GeneralizedLinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) - self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") + self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls", + linkPredictionCol="") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -1311,11 +1321,11 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha @since("2.0.0") def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls"): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=""): """ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls") + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="") Sets params for generalized 
linear regression. """ kwargs = self.setParams._input_kwargs @@ -1339,6 +1349,20 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha return self.getOrDefault(self.family) @since("2.0.0") + def setLinkPredictionCol(self, value): + """ + Sets the value of :py:attr:`linkPredictionCol`. + """ + return self._set(linkPredictionCol=value) + + @since("2.0.0") + def getLinkPredictionCol(self): + """ + Gets the value of linkPredictionCol or its default value. + """ + return self.getOrDefault(self.linkPredictionCol) + + @since("2.0.0") def setLink(self, value): """ Sets the value of :py:attr:`link`. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org