Repository: spark
Updated Branches:
  refs/heads/master c3a52a082 -> 784fcd532


[SPARK-6117] [SQL] Improvements to DataFrame.describe()

1. Slight modifications to the code to make it more readable.
2. Added a Python implementation.
3. Updated the documentation to state that we don't guarantee the output schema
   for this function, and that it should only be used for exploratory data analysis.
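
For reference, a minimal usage sketch (not part of this patch) contrasting describe
for exploration with agg for programmatic statistics, as the updated docs recommend.
It assumes a DataFrame df with numeric columns age and height, plus the
org.apache.spark.sql.functions helpers:

  import org.apache.spark.sql.functions._

  // Exploratory analysis: a human-readable summary table whose schema is not
  // guaranteed to stay stable across releases.
  df.describe("age", "height").show()

  // Programmatic summary statistics: spell out the aggregates with agg, so the
  // result schema stays under your control.
  df.agg(count(df("age")), avg(df("age")), min(df("age")), max(df("age"))).show()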

Author: Reynold Xin <r...@databricks.com>

Closes #5201 from rxin/df-describe and squashes the following commits:

25a7834 [Reynold Xin] Reset run-tests.
6abdfee [Reynold Xin] [SPARK-6117] [SQL] Improvements to DataFrame.describe()


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/784fcd53
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/784fcd53
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/784fcd53

Branch: refs/heads/master
Commit: 784fcd532784fcfd9bf0a1db71c9f71c469ee716
Parents: c3a52a0
Author: Reynold Xin <r...@databricks.com>
Authored: Thu Mar 26 12:26:13 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Thu Mar 26 12:26:13 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 19 ++++++++
 .../scala/org/apache/spark/sql/DataFrame.scala  | 46 ++++++++++++--------
 .../org/apache/spark/sql/DataFrameSuite.scala   |  3 +-
 3 files changed, 48 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/784fcd53/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index bf7c47b..d51309f 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -520,6 +520,25 @@ class DataFrame(object):
 
     orderBy = sort
 
+    def describe(self, *cols):
+        """Computes statistics for numeric columns.
+
+        This includes count, mean, stddev, min, and max. If no columns are
+        given, this function computes statistics for all numerical columns.
+
+        >>> df.describe().show()
+        summary age
+        count   2
+        mean    3.5
+        stddev  1.5
+        min     2
+        max     5
+        """
+        cols = ListConverter().convert(cols,
+                                       self.sql_ctx._sc._gateway._gateway_client)
+        jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
+        return DataFrame(jdf, self.sql_ctx)
+
     def head(self, n=None):
         """ Return the first `n` rows or the first row if n is None.
 

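A note on the Py4J plumbing above: the JVM-side describe is a Scala String*
(varargs) method, which at the bytecode level takes a Seq[String], so the Python
column names are converted to a java.util.List (ListConverter) and then to a
Scala Seq (PythonUtils.toSeq) before the call. On the Scala side the same
expansion looks like this (a sketch only, assuming a DataFrame df):

  // A String* parameter accepts either explicit arguments or an existing Seq
  // expanded with ": _*"; the Seq form is what the Python wrapper targets.
  val cols = Seq("age", "height")
  df.describe(cols: _*)
  df.describe("age", "height")  // equivalent call with explicit varargs
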
http://git-wip-us.apache.org/repos/asf/spark/blob/784fcd53/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index db56182..4c80359 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{expressions, ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.jdbc.JDBCWriteDetails
 import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
 import org.apache.spark.util.Utils
 
@@ -752,15 +752,17 @@ class DataFrame private[sql](
   }
 
   /**
-   * Compute numerical statistics for given columns of this [[DataFrame]]:
-   * count, mean (avg), stddev (standard deviation), min, max.
-   * Each row of the resulting [[DataFrame]] contains column with statistic name
-   * and columns with statistic results for each given column.
-   * If no columns are given then computes for all numerical columns.
+   * Computes statistics for numeric columns, including count, mean, stddev, min, and max.
+   * If no columns are given, this function computes statistics for all numerical columns.
+   *
+   * This function is meant for exploratory data analysis, as we make no guarantee about the
+   * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to
+   * programmatically compute summary statistics, use the `agg` function instead.
    *
    * {{{
-   *   df.describe("age", "height")
+   *   df.describe("age", "height").show()
    *
+   *   // output:
    *   // summary age   height
    *   // count   10.0  10.0
    *   // mean    53.3  178.05
@@ -768,13 +770,17 @@ class DataFrame private[sql](
    *   // min     18.0  163.0
    *   // max     92.0  192.0
    * }}}
+   *
+   * @group action
    */
   @scala.annotation.varargs
   def describe(cols: String*): DataFrame = {
 
-    def stddevExpr(expr: Expression) =
+    // TODO: Add stddev as an expression, and remove it from here.
+    def stddevExpr(expr: Expression): Expression =
       Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
 
+    // The list of summary statistics to compute, in the form of expressions.
     val statistics = List[(String, Expression => Expression)](
       "count" -> Count,
       "mean" -> Average,
@@ -782,24 +788,28 @@ class DataFrame private[sql](
       "min" -> Min,
       "max" -> Max)
 
-    val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
+    val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
 
-    val localAgg = if (aggCols.nonEmpty) {
+    val ret: Seq[Row] = if (outputCols.nonEmpty) {
       val aggExprs = statistics.flatMap { case (_, colToAgg) =>
-        aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
+        outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
       }
 
-      agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
-        .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
-        Row(statistic :: aggregation.toList: _*)
+      val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+
+      // Pivot the data so each summary is one row
+      row.grouped(outputCols.size).toSeq.zip(statistics).map {
+        case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*)
       }
     } else {
+      // If there are no output columns, just output a single column that contains the stats.
       statistics.map { case (name, _) => Row(name) }
     }
 
-    val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType)))
-    val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
-    sqlContext.createDataFrame(rowRdd, schema)
+    // The first column is string type, and the rest are double type.
+    val schema = StructType(
+      StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes
+    LocalRelation(schema, ret)
   }
 
   /**

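The temporary stddevExpr above relies on the identity Var(X) = E[X^2] - (E[X])^2,
i.e. it computes a population (not sample) standard deviation. A rough sketch of
the same computation expressed through agg, under the assumption of a DataFrame
df with a numeric column age and the org.apache.spark.sql.functions helpers
(hypothetical example, not part of the patch):

  import org.apache.spark.sql.functions._

  // Population standard deviation of "age" via sqrt(E[X^2] - (E[X])^2),
  // mirroring what stddevExpr builds out of Catalyst expressions.
  df.agg(
    sqrt(avg(df("age") * df("age")) - avg(df("age")) * avg(df("age"))).as("stddev_age")
  ).show()
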
http://git-wip-us.apache.org/repos/asf/spark/blob/784fcd53/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index afbedd1..fbc4065 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -444,7 +444,6 @@ class DataFrameSuite extends QueryTest {
   }
 
   test("describe") {
-
     val describeTestData = Seq(
       ("Bob",   16, 176),
       ("Alice", 32, 164),
@@ -465,7 +464,7 @@ class DataFrameSuite extends QueryTest {
       Row("min",     null, null),
       Row("max",     null, null))
 
-    def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
+    def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
 
     val describeTwoCols = describeTestData.describe("age", "height")
     assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))

