Repository: spark
Updated Branches:
  refs/heads/master 7b5dd3e3c -> 4dc8d7449
[SPARK-7240][SQL] Single pass covariance calculation for dataframes

Added the calculation of covariance between two columns to DataFrames.

cc mengxr rxin

Author: Burak Yavuz <brk...@gmail.com>

Closes #5825 from brkyvz/df-cov and squashes the following commits:

cb18046 [Burak Yavuz] changed to sample covariance
f2e862b [Burak Yavuz] fixed failed test
51e39b8 [Burak Yavuz] moved implementation
0c6a759 [Burak Yavuz] addressed math comments
8456eca [Burak Yavuz] fix pyStyle3
aa2ad29 [Burak Yavuz] fix pyStyle2
4e97a50 [Burak Yavuz] Merge branch 'master' of github.com:apache/spark into df-cov
e3b0b85 [Burak Yavuz] addressed comments v0.1
a7115f1 [Burak Yavuz] fix python style
7dc6dbc [Burak Yavuz] reorder imports
408cb77 [Burak Yavuz] initial commit


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4dc8d744
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4dc8d744
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4dc8d744

Branch: refs/heads/master
Commit: 4dc8d74491b101a794cf8d386d8c5ebc6019b75f
Parents: 7b5dd3e
Author: Burak Yavuz <brk...@gmail.com>
Authored: Fri May 1 13:29:17 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Fri May 1 13:29:17 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/__init__.py                  |  4 +-
 python/pyspark/sql/dataframe.py                 | 36 ++++++++-
 python/pyspark/sql/tests.py                     |  5 ++
 .../spark/sql/DataFrameStatFunctions.scala      | 12 ++-
 .../sql/execution/stat/StatFunctions.scala      | 80 ++++++++++++++++++++
 .../apache/spark/sql/JavaDataFrameSuite.java    |  7 ++
 .../apache/spark/sql/DataFrameStatSuite.scala   | 18 ++++-
 7 files changed, 157 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4dc8d744/python/pyspark/sql/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 6d54b9e..b60b991 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -54,7 +54,9 @@ del modname, sys
 from pyspark.sql.types import Row
 from pyspark.sql.context import SQLContext, HiveContext
 from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
+from pyspark.sql.dataframe import DataFrameStatFunctions
 
 __all__ = [
-    'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions'
+    'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
+    'DataFrameNaFunctions', 'DataFrameStatFunctions'
 ]


http://git-wip-us.apache.org/repos/asf/spark/blob/4dc8d744/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5908ebc..1f08c2d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -34,7 +34,8 @@
 from pyspark.sql.types import *
 from pyspark.sql.types import _create_cls, _parse_datatype_json_string
 
-__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"]
+__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions",
+           "DataFrameStatFunctions"]
 
 
 class DataFrame(object):
@@ -93,6 +94,12 @@ class DataFrame(object):
         """
         return DataFrameNaFunctions(self)
 
+    @property
+    def stat(self):
+        """Returns a :class:`DataFrameStatFunctions` for statistic functions.
+ """ + return DataFrameStatFunctions(self) + @ignore_unicode_prefix def toJSON(self, use_unicode=True): """Converts a :class:`DataFrame` into a :class:`RDD` of string. @@ -868,6 +875,20 @@ class DataFrame(object): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + def cov(self, col1, col2): + """ + Calculate the sample covariance for the given columns, specified by their names, as a + double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases. + + :param col1: The name of the first column + :param col2: The name of the second column + """ + if not isinstance(col1, str): + raise ValueError("col1 should be a string.") + if not isinstance(col2, str): + raise ValueError("col2 should be a string.") + return self._jdf.stat().cov(col1, col2) + @ignore_unicode_prefix def withColumn(self, colName, col): """Returns a new :class:`DataFrame` by adding a column. @@ -1311,6 +1332,19 @@ class DataFrameNaFunctions(object): fill.__doc__ = DataFrame.fillna.__doc__ +class DataFrameStatFunctions(object): + """Functionality for statistic functions with :class:`DataFrame`. + """ + + def __init__(self, df): + self.df = df + + def cov(self, col1, col2): + return self.df.cov(col1, col2) + + cov.__doc__ = DataFrame.cov.__doc__ + + def _test(): import doctest from pyspark.context import SparkContext http://git-wip-us.apache.org/repos/asf/spark/blob/4dc8d744/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5640bb5..44c8b6a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -387,6 +387,11 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_cov(self): + df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() + cov = df.stat.cov("a", "b") + self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) + def test_math_functions(self): df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() from pyspark.sql import mathfunctions as functions http://git-wip-us.apache.org/repos/asf/spark/blob/4dc8d744/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 42e5cbc..23652ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.stat.FrequentItems +import org.apache.spark.sql.execution.stat._ /** * :: Experimental :: @@ -65,4 +65,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: List[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } + + /** + * Calculate the sample covariance of two numerical columns of a DataFrame. + * @param col1 the name of the first column + * @param col2 the name of the second column + * @return the covariance of the two columns. 
+   */
+  def cov(col1: String, col2: String): Double = {
+    StatFunctions.calculateCov(df, Seq(col1, col2))
+  }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/4dc8d744/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
new file mode 100644
index 0000000..d4a94c2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.stat
+
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.types.{DoubleType, NumericType}
+
+private[sql] object StatFunctions {
+
+  /** Helper class to simplify tracking and merging counts. */
+  private class CovarianceCounter extends Serializable {
+    var xAvg = 0.0
+    var yAvg = 0.0
+    var Ck = 0.0
+    var count = 0L
+    // add an example to the calculation
+    def add(x: Double, y: Double): this.type = {
+      val oldX = xAvg
+      count += 1
+      xAvg += (x - xAvg) / count
+      yAvg += (y - yAvg) / count
+      Ck += (y - yAvg) * (x - oldX)
+      this
+    }
+    // merge counters from other partitions. Formula can be found at:
+    // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance
+    def merge(other: CovarianceCounter): this.type = {
+      val totalCount = count + other.count
+      Ck += other.Ck +
+        (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count
+      xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
+      yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
+      count = totalCount
+      this
+    }
+    // return the sample covariance for the observed examples
+    def cov: Double = Ck / (count - 1)
+  }
+
+  /**
+   * Calculate the covariance of two numerical columns of a DataFrame.
+   * @param df The DataFrame
+   * @param cols the column names
+   * @return the covariance of the two columns.
+   */
+  private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+    require(cols.length == 2, "Currently cov supports calculating the covariance " +
+      "between two columns.")
+    cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>
+      require(data.nonEmpty, s"Couldn't find column with name $name")
+      require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " +
+        s"with dataType ${data.get.dataType} not supported.")
+    }
+    val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
+    val counts = df.select(columns: _*).rdd.aggregate(new CovarianceCounter)(
+      seqOp = (counter, row) => {
+        counter.add(row.getDouble(0), row.getDouble(1))
+      },
+      combOp = (baseCounter, other) => {
+        baseCounter.merge(other)
+      })
+    counts.cov
+  }
+}


http://git-wip-us.apache.org/repos/asf/spark/blob/4dc8d744/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ebe96e6..96fe66d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -186,4 +186,11 @@ public class JavaDataFrameSuite {
     DataFrame results = df.stat().freqItems(cols, 0.2);
     Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
   }
+
+  @Test
+  public void testCovariance() {
+    DataFrame df = context.table("testData2");
+    Double result = df.stat().cov("a", "b");
+    Assert.assertTrue(Math.abs(result) < 1e-6);
+  }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/4dc8d744/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index bb1d29c..4f5a2ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -25,10 +25,11 @@ import org.apache.spark.sql.test.TestSQLContext.implicits._
 
 class DataFrameStatSuite extends FunSuite {
 
+  import TestData._
   val sqlCtx = TestSQLContext
-
+  def toLetter(i: Int): String = (i + 97).toChar.toString
+
   test("Frequent Items") {
-    def toLetter(i: Int): String = (i + 96).toChar.toString
     val rows = Array.tabulate(1000) { i =>
       if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
     }
@@ -44,4 +45,17 @@ class DataFrameStatSuite extends FunSuite {
 
     items2.getSeq[Double](0) should contain (-1.0)
   }
+
+  test("covariance") {
+    val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i)))
+    val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters")
+
+    val results = df.stat.cov("singles", "doubles")
+    assert(math.abs(results - 55.0 / 3) < 1e-6)
+    intercept[IllegalArgumentException] {
+      df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
+    }
+    val decimalRes = decimalData.stat.cov("a", "b")
+    assert(math.abs(decimalRes) < 1e-6)
+  }
 }
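----------------------------------------------------------------------

The single pass in this patch works because CovarianceCounter is a mergeable
summary: each partition folds its rows into a counter with a Welford-style
update, and RDD.aggregate then combines the per-partition counters using the
pairwise merge formula from the Wikipedia reference cited in the code. The
following is a minimal standalone sketch of that counter outside Spark (the
object and class names are illustrative, not part of this commit), with a
small check that a merged result matches the 55.0 / 3 value expected by the
tests above.

object CovarianceSketch {
  class Counter extends Serializable {
    var xAvg = 0.0   // running mean of x
    var yAvg = 0.0   // running mean of y
    var Ck = 0.0     // co-moment: running sum of (x - xAvg) * (y - yAvg)
    var count = 0L   // number of observations seen so far

    // Welford-style online update: shift both means, then update the
    // co-moment using the new y mean and the old x mean.
    def add(x: Double, y: Double): this.type = {
      val oldX = xAvg
      count += 1
      xAvg += (x - xAvg) / count
      yAvg += (y - yAvg) / count
      Ck += (y - yAvg) * (x - oldX)
      this
    }

    // Pairwise merge of two partial summaries: the co-moments add, plus a
    // correction term for the difference between the partitions' means.
    def merge(other: Counter): this.type = {
      val totalCount = count + other.count
      Ck += other.Ck +
        (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count
      xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
      yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
      count = totalCount
      this
    }

    // Sample covariance: divide the co-moment by (n - 1).
    def cov: Double = Ck / (count - 1)
  }

  def main(args: Array[String]): Unit = {
    // b = 2 * a for a in 0..9 gives sample covariance 2 * Var(a) = 55/3,
    // the same value the Python and Scala tests assert.
    val single = (0 until 10).foldLeft(new Counter)((c, i) => c.add(i, 2.0 * i))
    // Split the same data across two "partitions" and merge; results agree.
    val left = (0 until 5).foldLeft(new Counter)((c, i) => c.add(i, 2.0 * i))
    val right = (5 until 10).foldLeft(new Counter)((c, i) => c.add(i, 2.0 * i))
    val merged = left.merge(right)
    println(f"single pass: ${single.cov}%.6f, merged: ${merged.cov}%.6f")
    // both print 18.333333
  }
}

Because merge is exact (not an approximation), the covariance computed this
way is independent of how rows are distributed across partitions.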