Repository: spark Updated Branches: refs/heads/master 068c3158a -> 1712a7c70
[SPARK-6093] [MLLIB] Add RegressionMetrics in PySpark/MLlib https://issues.apache.org/jira/browse/SPARK-6093 Author: Yanbo Liang <yblia...@gmail.com> Closes #5941 from yanboliang/spark-6093 and squashes the following commits: 6934af3 [Yanbo Liang] change to @property aac3bc5 [Yanbo Liang] Add RegressionMetrics in PySpark/MLlib Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1712a7c7 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1712a7c7 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1712a7c7 Branch: refs/heads/master Commit: 1712a7c7057bf6dd5da8aea1d7fbecdf96ea4b32 Parents: 068c315 Author: Yanbo Liang <yblia...@gmail.com> Authored: Thu May 7 11:18:32 2015 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Thu May 7 11:18:32 2015 -0700 ---------------------------------------------------------------------- .../mllib/evaluation/RegressionMetrics.scala | 9 +++ python/pyspark/mllib/evaluation.py | 78 +++++++++++++++++++- 2 files changed, 85 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1712a7c7/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 693117d..e577bf8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} +import org.apache.spark.sql.DataFrame 
/** * :: Experimental :: @@ -33,6 +34,14 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Multivariate class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { /** + * An auxiliary constructor taking a DataFrame. + * @param predictionAndObservations a DataFrame with two double columns: + * prediction and observation + */ + private[mllib] def this(predictionAndObservations: DataFrame) = + this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1)))) + + /** * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. */ private lazy val summary: MultivariateStatisticalSummary = { http://git-wip-us.apache.org/repos/asf/spark/blob/1712a7c7/python/pyspark/mllib/evaluation.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 16cb49c..3e11df0 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -27,9 +27,9 @@ class BinaryClassificationMetrics(JavaModelWrapper): >>> scoreAndLabels = sc.parallelize([ ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2) >>> metrics = BinaryClassificationMetrics(scoreAndLabels) - >>> metrics.areaUnderROC() + >>> metrics.areaUnderROC 0.70... - >>> metrics.areaUnderPR() + >>> metrics.areaUnderPR 0.83... >>> metrics.unpersist() """ @@ -47,6 +47,7 @@ class BinaryClassificationMetrics(JavaModelWrapper): java_model = java_class(df._jdf) super(BinaryClassificationMetrics, self).__init__(java_model) + @property def areaUnderROC(self): """ Computes the area under the receiver operating characteristic @@ -54,6 +55,7 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ return self.call("areaUnderROC") + @property def areaUnderPR(self): """ Computes the area under the precision-recall curve. 
@@ -67,6 +69,78 @@ class BinaryClassificationMetrics(JavaModelWrapper): self.call("unpersist") +class RegressionMetrics(JavaModelWrapper): + """ + Evaluator for regression. + + >>> predictionAndObservations = sc.parallelize([ + ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) + >>> metrics = RegressionMetrics(predictionAndObservations) + >>> metrics.explainedVariance + 0.95... + >>> metrics.meanAbsoluteError + 0.5... + >>> metrics.meanSquaredError + 0.37... + >>> metrics.rootMeanSquaredError + 0.61... + >>> metrics.r2 + 0.94... + """ + + def __init__(self, predictionAndObservations): + """ + :param predictionAndObservations: an RDD of (prediction, observation) pairs. + """ + sc = predictionAndObservations.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ + StructField("prediction", DoubleType(), nullable=False), + StructField("observation", DoubleType(), nullable=False)])) + java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics + java_model = java_class(df._jdf) + super(RegressionMetrics, self).__init__(java_model) + + @property + def explainedVariance(self): + """ + Returns the explained variance regression score. + explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + """ + return self.call("explainedVariance") + + @property + def meanAbsoluteError(self): + """ + Returns the mean absolute error, which is a risk function corresponding to the + expected value of the absolute error loss or l1-norm loss. + """ + return self.call("meanAbsoluteError") + + @property + def meanSquaredError(self): + """ + Returns the mean squared error, which is a risk function corresponding to the + expected value of the squared error loss or quadratic loss. + """ + return self.call("meanSquaredError") + + @property + def rootMeanSquaredError(self): + """ + Returns the root mean squared error, which is defined as the square root of + the mean squared error. 
+ """ + return self.call("rootMeanSquaredError") + + @property + def r2(self): + """ + Returns R^2, the coefficient of determination. + """ + return self.call("r2") + + def _test(): import doctest from pyspark import SparkContext --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org