Repository: spark
Updated Branches:
  refs/heads/master 93581fbc1 -> d27daa54b


[SPARK-19636][ML] Feature parity for correlation statistics in MLlib

## What changes were proposed in this pull request?

This patch adds DataFrame-based support for the correlation statistics
found in `org.apache.spark.mllib.stat.Statistics`, following
the design doc discussed in the JIRA ticket.

The current implementation is a simple wrapper around the `spark.mllib` 
implementation. Future optimizations can be implemented at a later stage.

## How was this patch tested?

```
build/sbt "testOnly org.apache.spark.ml.stat.CorrelationSuite"
```

Author: Timothy Hunter <timhun...@databricks.com>

Closes #17108 from thunterdb/19636.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d27daa54
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d27daa54
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d27daa54

Branch: refs/heads/master
Commit: d27daa54bd341b29737a6352d9a1055151248ae7
Parents: 93581fb
Author: Timothy Hunter <timhun...@databricks.com>
Authored: Thu Mar 23 18:42:13 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Thu Mar 23 18:42:13 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/util/TestingUtils.scala |  8 ++
 .../org/apache/spark/ml/stat/Correlation.scala  | 86 ++++++++++++++++++++
 .../apache/spark/ml/stat/CorrelationSuite.scala | 77 ++++++++++++++++++
 3 files changed, 171 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d27daa54/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
----------------------------------------------------------------------
diff --git 
a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala 
b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
index 2327917..30edd00 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -32,6 +32,10 @@ object TestingUtils {
    * the relative tolerance is meaningless, so the exception will be raised to 
warn users.
    */
   private def RelativeErrorComparison(x: Double, y: Double, eps: Double): 
Boolean = {
+    // Special case for NaNs
+    if (x.isNaN && y.isNaN) {
+      return true
+    }
     val absX = math.abs(x)
     val absY = math.abs(y)
     val diff = math.abs(x - y)
@@ -49,6 +53,10 @@ object TestingUtils {
    * Private helper function for comparing two values using absolute tolerance.
    */
   private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): 
Boolean = {
+    // Special case for NaNs
+    if (x.isNaN && y.isNaN) {
+      return true
+    }
     math.abs(x - y) < eps
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d27daa54/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala 
b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
new file mode 100644
index 0000000..a7243cc
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
+import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
+import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.types.{StructField, StructType}
+
/**
 * :: Experimental ::
 * API for correlation functions in MLlib, compatible with DataFrames and Datasets.
 *
 * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]]
 * to spark.ml's Vector types.
 */
@Since("2.2.0")
@Experimental
object Correlation {

  /**
   * :: Experimental ::
   * Compute the correlation matrix for the input Dataset of Vectors using the specified method.
   * Methods currently supported: `pearson` (default), `spearman`.
   *
   * @param dataset A dataset or a dataframe
   * @param column The name of the column of vectors for which the correlation coefficient needs
   *               to be computed. This must be a column of the dataset, and it must contain
   *               non-null spark.ml Vector objects.
   * @param method String specifying the method to use for computing correlation.
   *               Supported: `pearson` (default), `spearman`
   * @return A dataframe that contains the correlation matrix of the column of vectors. This
   *         dataframe contains a single row and a single column of name
   *         '$METHODNAME($COLUMN)'.
   * @throws IllegalArgumentException if the content of this column is not of type Vector.
   *                                  A column name that cannot be resolved is rejected earlier
   *                                  by `Dataset.select`.
   *
   *  Here is how to access the correlation coefficient:
   *  {{{
   *    val data: Dataset[Vector] = ...
   *    val Row(coeff: Matrix) = Correlation.corr(data, "value").head
   *    // coeff now contains the Pearson correlation matrix.
   *  }}}
   *
   * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column
   * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
   * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to
   * avoid recomputing the common lineage.
   */
  @Since("2.2.0")
  def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
    val selected = dataset.select(column)
    // A single-column select always yields exactly one field. Checking its type here, on the
    // driver, makes a wrong column type fail fast with the documented IllegalArgumentException
    // instead of a MatchError inside a failed task.
    val inputType = selected.schema.head.dataType
    require(inputType == SQLDataTypes.VectorType,
      s"Column $column must be of type ${SQLDataTypes.VectorType.simpleString} " +
        s"but was actually of type ${inputType.simpleString}.")
    val rdd = selected.rdd.map {
      case Row(v: Vector) => OldVectors.fromML(v)
      // The type check above leaves null cells as the only other possibility.
      case _ => throw new IllegalArgumentException(
        s"Column $column must not contain null values to compute correlations.")
    }
    // Delegate the actual computation to the RDD-based spark.mllib implementation.
    val oldM = OldStatistics.corr(rdd, method)
    val name = s"$method($column)"
    val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
    dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema)
  }

  /**
   * :: Experimental ::
   * Compute the Pearson correlation matrix for the input Dataset of Vectors.
   * Equivalent to `corr(dataset, column, "pearson")`, so the single output column is named
   * `pearson($column)`.
   */
  @Since("2.2.0")
  def corr(dataset: Dataset[_], column: String): DataFrame = {
    corr(dataset, column, "pearson")
  }
}

http://git-wip-us.apache.org/repos/asf/spark/blob/d27daa54/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala
new file mode 100644
index 0000000..7d935e6
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat
+
+import breeze.linalg.{DenseMatrix => BDM}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.linalg.{Matrices, Matrix, Vectors}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+
class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {

  // Four 4-dimensional samples. The third coordinate is constant (all zeros), which is why the
  // expected matrices below hold NaN for every correlation involving it (except its diagonal).
  val data = Seq(
    Vectors.dense(1.0, 0.0, 0.0, -2.0),
    Vectors.dense(4.0, 5.0, 0.0, 3.0),
    Vectors.dense(6.0, 7.0, 0.0, 8.0),
    Vectors.dense(9.0, 0.0, 0.0, 1.0)
  )

  /** Single-column DataFrame (column "features") built from `data`; rebuilt on each access. */
  private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

  /** Extracts the single correlation matrix from a one-row result as a dense Breeze matrix. */
  private def extract(df: DataFrame): BDM[Double] = {
    val Array(Row(mat: Matrix)) = df.collect()
    mat.asBreeze.toDenseMatrix
  }

  test("corr(X) default, pearson") {
    val defaultMat = Correlation.corr(X, "features")
    val pearsonMat = Correlation.corr(X, "features", "pearson")
    // The result column is named "$method($column)"; the default method is Pearson.
    assert(defaultMat.columns === Array("pearson(features)"))
    assert(pearsonMat.columns === Array("pearson(features)"))
    // scalastyle:off
    val expected = Matrices.fromBreeze(BDM(
      (1.00000000, 0.05564149, Double.NaN, 0.4004714),
      (0.05564149, 1.00000000, Double.NaN, 0.9135959),
      (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
      (0.40047142, 0.91359586, Double.NaN, 1.0000000)))
    // scalastyle:on

    assert(Matrices.fromBreeze(extract(defaultMat)) ~== expected absTol 1e-4)
    assert(Matrices.fromBreeze(extract(pearsonMat)) ~== expected absTol 1e-4)
  }

  test("corr(X) spearman") {
    val spearmanMat = Correlation.corr(X, "features", "spearman")
    assert(spearmanMat.columns === Array("spearman(features)"))
    // scalastyle:off
    val expected = Matrices.fromBreeze(BDM(
      (1.0000000,  0.1054093,  Double.NaN, 0.4000000),
      (0.1054093,  1.0000000,  Double.NaN, 0.9486833),
      (Double.NaN, Double.NaN, 1.00000000, Double.NaN),
      (0.4000000,  0.9486833,  Double.NaN, 1.0000000)))
    // scalastyle:on
    assert(Matrices.fromBreeze(extract(spearmanMat)) ~== expected absTol 1e-4)
  }

}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to