Repository: spark Updated Branches: refs/heads/branch-1.6 85bb319a2 -> a6ee4f989
[SPARK-9301][SQL] Add collect_set and collect_list aggregate functions For now they are thin wrappers around the corresponding Hive UDAFs. One limitation with these in Hive 0.13.0 is they only support aggregating primitive types. I chose snake_case here instead of camelCase because it seems to be used in the majority of the multi-word fns. Do we also want to add these to `functions.py`? This approach was recommended here: https://github.com/apache/spark/pull/8592#issuecomment-154247089 marmbrus rxin Author: Nick Buroojy <nick.buro...@civitaslearning.com> Closes #9526 from nburoojy/nick/udaf-alias. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a6ee4f98 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a6ee4f98 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a6ee4f98 Branch: refs/heads/branch-1.6 Commit: a6ee4f989d020420dd08b97abb24802200ff23b2 Parents: 85bb319 Author: Nick Buroojy <nick.buro...@civitaslearning.com> Authored: Mon Nov 9 14:30:37 2015 -0800 Committer: Michael Armbrust <mich...@databricks.com> Committed: Mon Nov 9 14:30:37 2015 -0800 ---------------------------------------------------------------------- python/pyspark/sql/functions.py | 25 +++++++++++--------- python/pyspark/sql/tests.py | 17 +++++++++++++ .../scala/org/apache/spark/sql/functions.scala | 20 ++++++++++++++++ .../sql/hive/HiveDataFrameAnalyticsSuite.scala | 15 ++++++++++-- 4 files changed, 64 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/a6ee4f98/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2f7c2f4..962f676 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -124,17 +124,20 @@ _functions_1_4 = { _functions_1_6 = { # unary 
math functions - "stddev": "Aggregate function: returns the unbiased sample standard deviation of" + - " the expression in a group.", - "stddev_samp": "Aggregate function: returns the unbiased sample standard deviation of" + - " the expression in a group.", - "stddev_pop": "Aggregate function: returns population standard deviation of" + - " the expression in a group.", - "variance": "Aggregate function: returns the population variance of the values in a group.", - "var_samp": "Aggregate function: returns the unbiased variance of the values in a group.", - "var_pop": "Aggregate function: returns the population variance of the values in a group.", - "skewness": "Aggregate function: returns the skewness of the values in a group.", - "kurtosis": "Aggregate function: returns the kurtosis of the values in a group." + 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_samp': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_pop': 'Aggregate function: returns population standard deviation of' + + ' the expression in a group.', + 'variance': 'Aggregate function: returns the population variance of the values in a group.', + 'var_samp': 'Aggregate function: returns the unbiased variance of the values in a group.', + 'var_pop': 'Aggregate function: returns the population variance of the values in a group.', + 'skewness': 'Aggregate function: returns the skewness of the values in a group.', + 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', + 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', + 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + + ' eliminated.' 
} # math functions that take two arguments as input http://git-wip-us.apache.org/repos/asf/spark/blob/a6ee4f98/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4c03a0d..e224574 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1230,6 +1230,23 @@ class HiveContextSQLTests(ReusedPySparkTestCase): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_collect_functions(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql import functions + + self.assertEqual( + sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r), + [1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r), + [1, 1, 1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r), + ["1", "2"]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), + ["1", "2", "2", "2"]) + if __name__ == "__main__": if xmlrunner: http://git-wip-us.apache.org/repos/asf/spark/blob/a6ee4f98/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0462758..3f0b24b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -175,6 +175,26 @@ object functions { def avg(columnName: String): Column = avg(Column(columnName)) /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. 
+ * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(e: Column): Column = callUDF("collect_list", e) + + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(e: Column): Column = callUDF("collect_set", e) + + /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. * * @group agg_funcs http://git-wip-us.apache.org/repos/asf/spark/blob/a6ee4f98/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 2e5cae4..9864acf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.scalatest.BeforeAndAfterAll @@ -32,7 +32,7 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with private var testData: DataFrame = _ override def beforeAll() { - testData = Seq((1, 2), (2, 4)).toDF("a", "b") + testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") } @@ -52,6 +52,17 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with ) } + test("collect functions") { + checkAnswer( + testData.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4))) + ) + checkAnswer( + 
testData.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 4))) + ) + } + test("cube") { checkAnswer( testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org