Repository: spark
Updated Branches:
  refs/heads/branch-1.6 85bb319a2 -> a6ee4f989


[SPARK-9301][SQL] Add collect_set and collect_list aggregate functions

For now they are thin wrappers around the corresponding Hive UDAFs.

One limitation with these in Hive 0.13.0 is that they only support aggregating
primitive types.

I chose snake_case here instead of camelCase because it seems to be used in the
majority of the multi-word functions.

Do we also want to add these to `functions.py`?

This approach was recommended here: 
https://github.com/apache/spark/pull/8592#issuecomment-154247089

marmbrus rxin

Author: Nick Buroojy <nick.buro...@civitaslearning.com>

Closes #9526 from nburoojy/nick/udaf-alias.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a6ee4f98
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a6ee4f98
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a6ee4f98

Branch: refs/heads/branch-1.6
Commit: a6ee4f989d020420dd08b97abb24802200ff23b2
Parents: 85bb319
Author: Nick Buroojy <nick.buro...@civitaslearning.com>
Authored: Mon Nov 9 14:30:37 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Nov 9 14:30:37 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 25 +++++++++++---------
 python/pyspark/sql/tests.py                     | 17 +++++++++++++
 .../scala/org/apache/spark/sql/functions.scala  | 20 ++++++++++++++++
 .../sql/hive/HiveDataFrameAnalyticsSuite.scala  | 15 ++++++++++--
 4 files changed, 64 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a6ee4f98/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 2f7c2f4..962f676 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -124,17 +124,20 @@ _functions_1_4 = {
 
 _functions_1_6 = {
     # unary math functions
-    "stddev": "Aggregate function: returns the unbiased sample standard 
deviation of" +
-              " the expression in a group.",
-    "stddev_samp": "Aggregate function: returns the unbiased sample standard 
deviation of" +
-              " the expression in a group.",
-    "stddev_pop": "Aggregate function: returns population standard deviation 
of" +
-              " the expression in a group.",
-    "variance": "Aggregate function: returns the population variance of the 
values in a group.",
-    "var_samp": "Aggregate function: returns the unbiased variance of the 
values in a group.",
-    "var_pop":  "Aggregate function: returns the population variance of the 
values in a group.",
-    "skewness": "Aggregate function: returns the skewness of the values in a 
group.",
-    "kurtosis": "Aggregate function: returns the kurtosis of the values in a 
group."
+    'stddev': 'Aggregate function: returns the unbiased sample standard 
deviation of' +
+              ' the expression in a group.',
+    'stddev_samp': 'Aggregate function: returns the unbiased sample standard 
deviation of' +
+                   ' the expression in a group.',
+    'stddev_pop': 'Aggregate function: returns population standard deviation 
of' +
+                  ' the expression in a group.',
+    'variance': 'Aggregate function: returns the population variance of the 
values in a group.',
+    'var_samp': 'Aggregate function: returns the unbiased variance of the 
values in a group.',
+    'var_pop':  'Aggregate function: returns the population variance of the 
values in a group.',
+    'skewness': 'Aggregate function: returns the skewness of the values in a 
group.',
+    'kurtosis': 'Aggregate function: returns the kurtosis of the values in a 
group.',
+    'collect_list': 'Aggregate function: returns a list of objects with 
duplicates.',
+    'collect_set': 'Aggregate function: returns a set of objects with 
duplicate elements' +
+                   ' eliminated.'
 }
 
 # math functions that take two arguments as input

http://git-wip-us.apache.org/repos/asf/spark/blob/a6ee4f98/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4c03a0d..e224574 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1230,6 +1230,23 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
         for r, ex in zip(rs, expected):
             self.assertEqual(tuple(r), ex[:len(r)])
 
+    def test_collect_functions(self):
+        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
+        from pyspark.sql import functions
+
+        self.assertEqual(
+            
sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r),
+            [1, 2])
+        self.assertEqual(
+            
sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r),
+            [1, 1, 1, 2])
+        self.assertEqual(
+            
sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r),
+            ["1", "2"])
+        self.assertEqual(
+            
sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
+            ["1", "2", "2", "2"])
+
 
 if __name__ == "__main__":
     if xmlrunner:

http://git-wip-us.apache.org/repos/asf/spark/blob/a6ee4f98/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0462758..3f0b24b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -175,6 +175,26 @@ object functions {
   def avg(columnName: String): Column = avg(Column(columnName))
 
   /**
+   * Aggregate function: returns a list of objects with duplicates.
+   *
+   * For now this is an alias for the collect_list Hive UDAF.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def collect_list(e: Column): Column = callUDF("collect_list", e)
+
+  /**
+   * Aggregate function: returns a set of objects with duplicate elements 
eliminated.
+   *
+   * For now this is an alias for the collect_set Hive UDAF.
+   *
+   * @group agg_funcs
+   * @since 1.6.0
+   */
+  def collect_set(e: Column): Column = callUDF("collect_set", e)
+
+  /**
    * Aggregate function: returns the Pearson Correlation Coefficient for two 
columns.
    *
    * @group agg_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/a6ee4f98/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
index 2e5cae4..9864acf 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.hive
 
-import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.scalatest.BeforeAndAfterAll
@@ -32,7 +32,7 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with 
TestHiveSingleton with
   private var testData: DataFrame = _
 
   override def beforeAll() {
-    testData = Seq((1, 2), (2, 4)).toDF("a", "b")
+    testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
     hiveContext.registerDataFrameAsTable(testData, "mytable")
   }
 
@@ -52,6 +52,17 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with 
TestHiveSingleton with
     )
   }
 
+  test("collect functions") {
+    checkAnswer(
+      testData.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
+    )
+    checkAnswer(
+      testData.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
+    )
+  }
+
   test("cube") {
     checkAnswer(
       testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to