Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/19872#discussion_r159108798 --- Diff: python/pyspark/sql/tests.py --- @@ -4052,6 +4066,323 @@ def test_unsupported_types(self): df.groupby('id').apply(f).collect() +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyAggTests(ReusedSQLTestCase): + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def plus_one(self): + from pyspark.sql.functions import udf + + @udf('double') + def plus_one(v): + assert isinstance(v, float) + return v + 1 + return plus_one + + @property + def plus_two(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.SCALAR) + def plus_two(v): + assert isinstance(v, pd.Series) + return v + 2 + return plus_two + + @property + def mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def mean_udf(v): + return v.mean() + return mean_udf + + @property + def sum_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def sum_udf(v): + return v.sum() + return sum_udf + + @property + def weighted_mean_udf(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def weighted_mean_udf(v, w): + return np.average(v, weights=w) + return weighted_mean_udf + + def test_basic(self): + from pyspark.sql.functions import col, lit, sum, mean + + self.spark.conf.set("spark.sql.codegen.wholeStage", False) + + df = self.data + weighted_mean_udf = 
self.weighted_mean_udf + + result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') + expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean_udf(v, 1.0)')).sort('id') + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ + .sort(df.id + 1) + expected2 = df.groupby((col('id') + 1))\ + .agg(mean(df.v).alias('weighted_mean_udf(v, 1.0)')).sort(df.id + 1) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') + expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean_udf(v, w)')).sort('id') + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + result4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(weighted_mean_udf(df.v, df.w))\ + .sort('id') + expected4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(mean(df.v).alias('weighted_mean_udf(v, w)'))\ + .sort('id') + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_array(self): + from pyspark.sql.types import ArrayType, DoubleType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return [v.mean(), v.std()] + + def test_struct(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): + @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return (v.mean(), v.std()) + + def test_alias(self): + from pyspark.sql.functions import mean + + df = self.data + mean_udf = self.mean_udf + + result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')) 
+ + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + from pyspark.sql.functions import sum, mean + + df = self.data + sum_udf = self.sum_udf + + result1 = (df.groupby('id') + .agg(sum_udf(df.v) + 1) + .sort('id')) + + expected1 = (df.groupby('id') + .agg((sum(df.v) + 1).alias('(sum_udf(v) + 1)')) + .sort('id')) + + result2 = (df.groupby('id') + .agg(sum_udf(df.v + 1)) + .sort('id')) + + expected2 = (df.groupby('id') + .agg(sum(df.v + 1).alias('sum_udf((v + 1))')) + .sort('id')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_mixed_udf(self): + from pyspark.sql.functions import sum, mean + + df = self.data + plus_one = self.plus_one + plus_two = self.plus_two + sum_udf = self.sum_udf + + result1 = (df.groupby('id') + .agg(plus_one(sum_udf(df.v))) + .sort('id')) + + expected1 = (df.groupby('id') + .agg(plus_one(sum(df.v)).alias("plus_one(sum_udf(v))")) + .sort('id')) + + result2 = (df.groupby('id') + .agg(sum_udf(plus_one(df.v))) + .sort('id')) + + expected2 = (df.groupby('id') + .agg(sum(df.v + 1).alias("sum_udf(plus_one(v))")) + .sort('id')) + + result3 = (df.groupby('id') + .agg(sum_udf(plus_two(df.v))) + .sort('id')) + + expected3 = (df.groupby('id') + .agg(sum(df.v + 2).alias("sum_udf(plus_two(v))")) + .sort('id')) + + result4 = (df.groupby('id') + .agg(plus_two(sum_udf(df.v))) + .sort('id')) + + expected4 = (df.groupby('id') + .agg(plus_two(sum(df.v)).alias("plus_two(sum_udf(v))")) + .sort('id')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_multiple(self): + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + mean_udf = self.mean_udf + sum_udf = self.sum_udf 
+ weighted_mean_udf = self.weighted_mean_udf + + result1 = (df.groupBy('id') + .agg(mean_udf(df.v), + sum_udf(df.v), + weighted_mean_udf(df.v, df.w)) + .sort('id') + .toPandas()) + + expected1 = (df.groupBy('id') + .agg(mean(df.v).alias('mean_udf(v)'), + sum(df.v).alias('sum_udf(v)'), + mean(df.v).alias('weighted_mean_udf(v, w)')) + .sort('id') + .toPandas()) + + result2 = (df.groupBy('id', 'v') + .agg(mean_udf(df.v), + sum_udf(df.id)) + .sort('id', 'v') + .toPandas()) + + expected2 = (df.groupBy('id', 'v') + .agg(mean_udf(df.v).alias('mean_udf(v)'), + sum_udf(df.id).alias('sum_udf(id)')) + .sort('id', 'v') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + + def test_complex(self): + from pyspark.sql.functions import col, sum + + df = self.data + plus_one = self.plus_one + plus_two = self.plus_two + sum_udf = self.sum_udf + + result1 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby('id') + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_one(sum_udf(col('v1'))), + sum_udf(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + expected1 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby('id') + .agg(sum(col('v')).alias('sum_udf(v)'), + sum(col('v1') + 3).alias('sum_udf((v1 + 3))'), + (sum(col('v2')) + 5).alias('(sum_udf(v2) + 5)'), + plus_one(sum(col('v1'))).alias('plus_one(sum_udf(v1))'), + sum(col('v2') + 1).alias('sum_udf(plus_one(v2))')) + .sort('id') + .toPandas()) + + result2 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby('id') + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_two(sum_udf(col('v1'))), + sum_udf(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + expected2 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby('id') + .agg(sum(col('v')).alias('sum_udf(v)'), + sum(col('v1') + 3).alias('sum_udf((v1 + 3))'), + 
(sum(col('v2')) + 5).alias('(sum_udf(v2) + 5)'), + plus_two(sum(col('v1'))).alias('plus_two(sum_udf(v1))'), + sum(col('v2') + 2).alias('sum_udf(plus_two(v2))')) + .sort('id') + .toPandas()) + + result3 = (df.groupby('id') + .agg(sum_udf(df.v).alias('v')) + .groupby('id') + .agg(sum_udf(col('v')).alias('sum_v')) + .sort('id') + .toPandas()) + + expected3 = (df.groupby('id') + .agg(sum(df.v).alias('v')) + .groupby('id') + .agg(sum(col('v')).alias('sum_v')) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) + --- End diff -- Added `test_complex_grouping`
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org