Repository: spark
Updated Branches:
  refs/heads/master ad3cc1312 -> 1a52a6237


[SPARK-20076][ML][PYSPARK] Add Python interface for ml.stats.Correlation

## What changes were proposed in this pull request?

The DataFrame-based support for correlation statistics was added in #17108.
This patch adds the Python interface for it.

## How was this patch tested?

Python unit test.

Please review http://spark.apache.org/contributing.html before opening a pull 
request.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #17494 from viirya/correlation-python-api.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1a52a623
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1a52a623
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1a52a623

Branch: refs/heads/master
Commit: 1a52a62377a87cec493c8c6711bfd44e779c7973
Parents: ad3cc13
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Fri Apr 7 11:00:10 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Fri Apr 7 11:00:10 2017 +0200

----------------------------------------------------------------------
 .../org/apache/spark/ml/stat/Correlation.scala  |  8 +--
 python/pyspark/ml/stat.py                       | 61 ++++++++++++++++++++
 2 files changed, 65 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1a52a623/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala 
b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
index d3c84b7..e185bc8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
@@ -38,7 +38,7 @@ object Correlation {
 
   /**
    * :: Experimental ::
-   * Compute the correlation matrix for the input RDD of Vectors using the 
specified method.
+   * Compute the correlation matrix for the input Dataset of Vectors using the 
specified method.
    * Methods currently supported: `pearson` (default), `spearman`.
    *
    * @param dataset A dataset or a dataframe
@@ -56,14 +56,14 @@ object Correlation {
    *  Here is how to access the correlation coefficient:
    *  {{{
    *    val data: Dataset[Vector] = ...
-   *    val Row(coeff: Matrix) = Statistics.corr(data, "value").head
+   *    val Row(coeff: Matrix) = Correlation.corr(data, "value").head
    *    // coeff now contains the Pearson correlation matrix.
    *  }}}
    *
    * @note For Spearman, a rank correlation, we need to create an RDD[Double] 
for each column
    * and sort it in order to retrieve the ranks and then join the columns back 
into an RDD[Vector],
-   * which is fairly costly. Cache the input RDD before calling corr with 
`method = "spearman"` to
-   * avoid recomputing the common lineage.
+   * which is fairly costly. Cache the input Dataset before calling corr with 
`method = "spearman"`
+   * to avoid recomputing the common lineage.
    */
   @Since("2.2.0")
   def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {

http://git-wip-us.apache.org/repos/asf/spark/blob/1a52a623/python/pyspark/ml/stat.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index db043ff..079b083 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -71,6 +71,67 @@ class ChiSquareTest(object):
         return _java2py(sc, javaTestObj.test(*args))
 
 
+class Correlation(object):
+    """
+    .. note:: Experimental
+
+    Compute the correlation matrix for the input dataset of Vectors using the 
specified method.
+    Methods currently supported: `pearson` (default), `spearman`.
+
+    .. note:: For Spearman, a rank correlation, we need to create an 
RDD[Double] for each column
+      and sort it in order to retrieve the ranks and then join the columns 
back into an RDD[Vector],
+      which is fairly costly. Cache the input Dataset before calling corr with 
`method = 'spearman'`
+      to avoid recomputing the common lineage.
+
+    :param dataset:
+      A dataset or a dataframe.
+    :param column:
+      The name of the column of vectors for which the correlation coefficient 
needs
+      to be computed. This must be a column of the dataset, and it must contain
+      Vector objects.
+    :param method:
+      String specifying the method to use for computing correlation.
+      Supported: `pearson` (default), `spearman`.
+    :return:
+      A dataframe that contains the correlation matrix of the column of 
vectors. This
+      dataframe contains a single row and a single column of name
+      '$METHODNAME($COLUMN)'.
+
+    >>> from pyspark.ml.linalg import Vectors
+    >>> from pyspark.ml.stat import Correlation
+    >>> dataset = [[Vectors.dense([1, 0, 0, -2])],
+    ...            [Vectors.dense([4, 5, 0, 3])],
+    ...            [Vectors.dense([6, 7, 0, 8])],
+    ...            [Vectors.dense([9, 0, 0, 1])]]
+    >>> dataset = spark.createDataFrame(dataset, ['features'])
+    >>> pearsonCorr = Correlation.corr(dataset, 'features', 
'pearson').collect()[0][0]
+    >>> print(str(pearsonCorr).replace('nan', 'NaN'))
+    DenseMatrix([[ 1.        ,  0.0556...,         NaN,  0.4004...],
+                 [ 0.0556...,  1.        ,         NaN,  0.9135...],
+                 [        NaN,         NaN,  1.        ,         NaN],
+                 [ 0.4004...,  0.9135...,         NaN,  1.        ]])
+    >>> spearmanCorr = Correlation.corr(dataset, 'features', 
method='spearman').collect()[0][0]
+    >>> print(str(spearmanCorr).replace('nan', 'NaN'))
+    DenseMatrix([[ 1.        ,  0.1054...,         NaN,  0.4       ],
+                 [ 0.1054...,  1.        ,         NaN,  0.9486... ],
+                 [        NaN,         NaN,  1.        ,         NaN],
+                 [ 0.4       ,  0.9486... ,         NaN,  1.        ]])
+
+    .. versionadded:: 2.2.0
+
+    """
+    @staticmethod
+    @since("2.2.0")
+    def corr(dataset, column, method="pearson"):
+        """
+        Compute the correlation matrix with specified method using dataset.
+        """
+        sc = SparkContext._active_spark_context
+        javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation
+        args = [_py2java(sc, arg) for arg in (dataset, column, method)]
+        return _java2py(sc, javaCorrObj.corr(*args))
+
+
 if __name__ == "__main__":
     import doctest
     import pyspark.ml.stat


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to