Repository: spark Updated Branches: refs/heads/master ad3cc1312 -> 1a52a6237
[SPARK-20076][ML][PYSPARK] Add Python interface for ml.stat.Correlation ## What changes were proposed in this pull request? DataFrame-based support for correlation statistics was added in #17108. This patch adds the Python interface for it. ## How was this patch tested? Python unit test. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh <vii...@gmail.com> Closes #17494 from viirya/correlation-python-api. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1a52a623 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1a52a623 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1a52a623 Branch: refs/heads/master Commit: 1a52a62377a87cec493c8c6711bfd44e779c7973 Parents: ad3cc13 Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Fri Apr 7 11:00:10 2017 +0200 Committer: Nick Pentreath <ni...@za.ibm.com> Committed: Fri Apr 7 11:00:10 2017 +0200 ---------------------------------------------------------------------- .../org/apache/spark/ml/stat/Correlation.scala | 8 +-- python/pyspark/ml/stat.py | 61 ++++++++++++++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1a52a623/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala index d3c84b7..e185bc8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala @@ -38,7 +38,7 @@ object Correlation { /** * :: Experimental :: - * Compute the correlation matrix for the input RDD of Vectors using the specified method. 
+ * Compute the correlation matrix for the input Dataset of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * * @param dataset A dataset or a dataframe @@ -56,14 +56,14 @@ object Correlation { * Here is how to access the correlation coefficient: * {{{ * val data: Dataset[Vector] = ... - * val Row(coeff: Matrix) = Statistics.corr(data, "value").head + * val Row(coeff: Matrix) = Correlation.corr(data, "value").head * // coeff now contains the Pearson correlation matrix. * }}} * * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], - * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to - * avoid recomputing the common lineage. + * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"` + * to avoid recomputing the common lineage. */ @Since("2.2.0") def corr(dataset: Dataset[_], column: String, method: String): DataFrame = { http://git-wip-us.apache.org/repos/asf/spark/blob/1a52a623/python/pyspark/ml/stat.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index db043ff..079b083 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -71,6 +71,67 @@ class ChiSquareTest(object): return _java2py(sc, javaTestObj.test(*args)) +class Correlation(object): + """ + .. note:: Experimental + + Compute the correlation matrix for the input dataset of Vectors using the specified method. + Methods currently supported: `pearson` (default), `spearman`. + + .. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column + and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + which is fairly costly. 
Cache the input Dataset before calling corr with `method = 'spearman'` + to avoid recomputing the common lineage. + + :param dataset: + A dataset or a dataframe. + :param column: + The name of the column of vectors for which the correlation coefficient needs + to be computed. This must be a column of the dataset, and it must contain + Vector objects. + :param method: + String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman`. + :return: + A dataframe that contains the correlation matrix of the column of vectors. This + dataframe contains a single row and a single column of name + '$METHODNAME($COLUMN)'. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import Correlation + >>> dataset = [[Vectors.dense([1, 0, 0, -2])], + ... [Vectors.dense([4, 5, 0, 3])], + ... [Vectors.dense([6, 7, 0, 8])], + ... [Vectors.dense([9, 0, 0, 1])]] + >>> dataset = spark.createDataFrame(dataset, ['features']) + >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] + >>> print(str(pearsonCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], + [ 0.0556..., 1. , NaN, 0.9135...], + [ NaN, NaN, 1. , NaN], + [ 0.4004..., 0.9135..., NaN, 1. ]]) + >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] + >>> print(str(spearmanCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], + [ 0.1054..., 1. , NaN, 0.9486... ], + [ NaN, NaN, 1. , NaN], + [ 0.4 , 0.9486... , NaN, 1. ]]) + + .. versionadded:: 2.2.0 + + """ + @staticmethod + @since("2.2.0") + def corr(dataset, column, method="pearson"): + """ + Compute the correlation matrix with specified method using dataset. 
+ """ + sc = SparkContext._active_spark_context + javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation + args = [_py2java(sc, arg) for arg in (dataset, column, method)] + return _java2py(sc, javaCorrObj.corr(*args)) + + if __name__ == "__main__": import doctest import pyspark.ml.stat --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org