Github user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19872#discussion_r161855488
  
    --- Diff: python/pyspark/sql/tests.py ---
    @@ -4279,6 +4272,386 @@ def test_unsupported_types(self):
                     df.groupby('id').apply(f).collect()
     
     
    +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
    +class GroupbyAggTests(ReusedSQLTestCase):
    +
    +    @property
    +    def data(self):
    +        from pyspark.sql.functions import array, explode, col, lit
    +        return self.spark.range(10).toDF('id') \
    +            .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
    +            .withColumn("v", explode(col('vs'))) \
    +            .drop('vs') \
    +            .withColumn('w', lit(1.0))
    +
    +    @property
    +    def plus_one(self):
    +        from pyspark.sql.functions import udf
    +
    +        @udf('double')
    +        def plus_one(v):
    +            assert isinstance(v, (int, float))
    +            return v + 1
    +        return plus_one
    +
    +    @property
    +    def plus_two(self):
    +        import pandas as pd
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.SCALAR)
    +        def plus_two(v):
    +            assert isinstance(v, pd.Series)
    +            return v + 2
    +        return plus_two
    +
    +    @property
    +    def mean_udf(self):
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.GROUP_AGG)
    +        def mean_udf(v):
    +            return v.mean()
    +        return mean_udf
    +
    +    @property
    +    def sum_udf(self):
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.GROUP_AGG)
    +        def sum_udf(v):
    +            return v.sum()
    +        return sum_udf
    +
    +    @property
    +    def weighted_mean_udf(self):
    +        import numpy as np
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        @pandas_udf('double', PandasUDFType.GROUP_AGG)
    +        def weighted_mean_udf(v, w):
    +            return np.average(v, weights=w)
    +        return weighted_mean_udf
    +
    +    def test_basic(self):
    +        from pyspark.sql.functions import col, lit, sum, mean
    +
    +        df = self.data
    +        weighted_mean_udf = self.weighted_mean_udf
    +
    +        result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id')
    +        expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean_udf(v, 1.0)')).sort('id')
    +        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
    +
    +        result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\
    +            .sort(df.id + 1)
    +        expected2 = df.groupby((col('id') + 1))\
    +            .agg(mean(df.v).alias('weighted_mean_udf(v, 1.0)')).sort(df.id + 1)
    +        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
    +
    +        result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id')
    +        expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean_udf(v, w)')).sort('id')
    +        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
    +
    +        result4 = df.groupby((col('id') + 1).alias('id'))\
    +            .agg(weighted_mean_udf(df.v, df.w))\
    +            .sort('id')
    +        expected4 = df.groupby((col('id') + 1).alias('id'))\
    +            .agg(mean(df.v).alias('weighted_mean_udf(v, w)'))\
    +            .sort('id')
    +        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
    +
    +    def test_array(self):
    +        from pyspark.sql.types import ArrayType, DoubleType
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        with QuietTest(self.sc):
    +            with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
    +                @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG)
    +                def mean_and_std_udf(v):
    +                    return [v.mean(), v.std()]
    +
    +    def test_struct(self):
    +        from pyspark.sql.functions import pandas_udf, PandasUDFType
    +
    +        with QuietTest(self.sc):
    +            with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
    +                @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG)
    +                def mean_and_std_udf(v):
    +                    return (v.mean(), v.std())
    +
    +    def test_alias(self):
    +        from pyspark.sql.functions import mean
    +
    +        df = self.data
    +        mean_udf = self.mean_udf
    +
    +        result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias'))
    +        expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias'))
    +
    +        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
    +
    +    def test_mixed_sql(self):
    +        from pyspark.sql.functions import sum, mean
    +
    +        df = self.data
    +        sum_udf = self.sum_udf
    +
    +        result1 = (df.groupby('id')
    +                   .agg(sum_udf(df.v) + 1)
    +                   .sort('id'))
    +
    +        expected1 = (df.groupby('id')
    +                     .agg((sum(df.v) + 1).alias('(sum_udf(v) + 1)'))
    +                     .sort('id'))
    +
    +        result2 = (df.groupby('id')
    +                     .agg(sum_udf(df.v + 1))
    +                     .sort('id'))
    +
    +        expected2 = (df.groupby('id')
    +                       .agg(sum(df.v + 1).alias('sum_udf((v + 1))'))
    +                       .sort('id'))
    +
    +        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
    +        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
    +
    +    def test_mixed_udf(self):
    +        from pyspark.sql.functions import sum, mean
    +
    +        df = self.data
    +        plus_one = self.plus_one
    +        plus_two = self.plus_two
    +        sum_udf = self.sum_udf
    +
    +        result1 = (df.groupby('id')
    +                   .agg(plus_one(sum_udf(df.v)))
    +                   .sort('id'))
    +
    +        expected1 = (df.groupby('id')
    +                     .agg(plus_one(sum(df.v)).alias("plus_one(sum_udf(v))"))
    +                     .sort('id'))
    +
    +        result2 = (df.groupby('id')
    +                   .agg(sum_udf(plus_one(df.v)))
    +                   .sort('id'))
    +
    +        expected2 = (df.groupby('id')
    +                     .agg(sum(df.v + 1).alias("sum_udf(plus_one(v))"))
    +                     .sort('id'))
    +
    +        result3 = (df.groupby('id')
    +                   .agg(sum_udf(plus_two(df.v)))
    +                   .sort('id'))
    +
    +        expected3 = (df.groupby('id')
    +                     .agg(sum(df.v + 2).alias("sum_udf(plus_two(v))"))
    +                     .sort('id'))
    +
    +        result4 = (df.groupby('id')
    +                   .agg(plus_two(sum_udf(df.v)))
    +                   .sort('id'))
    +
    +        expected4 = (df.groupby('id')
    +                     .agg(plus_two(sum(df.v)).alias("plus_two(sum_udf(v))"))
    +                     .sort('id'))
    +
    +        result5 = (df.groupby(plus_one(df.id))
    +                   .agg(plus_one(sum_udf(plus_one(df.v))))
    +                   .sort('plus_one(id)'))
    +        expected5 = (df.groupby(plus_one(df.id))
    +                     .agg(plus_one(sum(plus_one(df.v))).alias('plus_one(sum_udf(plus_one(v)))'))
    +                     .sort('plus_one(id)'))
    +
    +        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
    +        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
    +        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
    +        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
    +        self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
    +
    +    def test_multiple(self):
    +        from pyspark.sql.functions import col, lit, sum, mean
    +
    +        df = self.data
    +        mean_udf = self.mean_udf
    +        sum_udf = self.sum_udf
    +        weighted_mean_udf = self.weighted_mean_udf
    +
    +        result1 = (df.groupBy('id')
    +                   .agg(mean_udf(df.v),
    +                        sum_udf(df.v),
    +                        weighted_mean_udf(df.v, df.w))
    +                   .sort('id')
    +                   .toPandas())
    +
    +        expected1 = (df.groupBy('id')
    +                     .agg(mean(df.v).alias('mean_udf(v)'),
    +                          sum(df.v).alias('sum_udf(v)'),
    +                          mean(df.v).alias('weighted_mean_udf(v, w)'))
    +                     .sort('id')
    +                     .toPandas())
    +
    +        result2 = (df.groupBy('id', 'v')
    +                   .agg(mean_udf(df.v),
    +                        sum_udf(df.id))
    +                   .sort('id', 'v')
    +                   .toPandas())
    +
    +        expected2 = (df.groupBy('id', 'v')
    +                     .agg(mean_udf(df.v).alias('mean_udf(v)'),
    +                          sum_udf(df.id).alias('sum_udf(id)'))
    +                     .sort('id', 'v')
    +                     .toPandas())
    +
    +        self.assertPandasEqual(expected1, result1)
    +        self.assertPandasEqual(expected2, result2)
    +
    +    def test_complex_grouping(self):
    --- End diff ---
    
    Changed to `test_complex_groupby`
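
    Since the diff above is truncated at the method definition, a rough sketch of what a
    `test_complex_groupby` case could cover is shown below: grouping by a SQL expression
    and by a Python UDF rather than a plain column. This is illustrative only, reusing the
    `data`, `sum_udf` and `plus_one` fixtures defined earlier in the class, and is not the
    actual test body from the PR:

        def test_complex_groupby(self):
            from pyspark.sql.functions import sum

            df = self.data
            sum_udf = self.sum_udf
            plus_one = self.plus_one

            # Group by a SQL expression (id % 2) instead of a plain column.
            result1 = df.groupby(df.id % 2).agg(sum_udf(df.v)).sort(df.id % 2)
            expected1 = df.groupby(df.id % 2) \
                .agg(sum(df.v).alias('sum_udf(v)')).sort(df.id % 2)
            self.assertPandasEqual(expected1.toPandas(), result1.toPandas())

            # Group by a Python UDF applied to the grouping column.
            result2 = df.groupby(plus_one(df.id)) \
                .agg(sum_udf(df.v)).sort('plus_one(id)')
            expected2 = df.groupby(plus_one(df.id)) \
                .agg(sum(df.v).alias('sum_udf(v)')).sort('plus_one(id)')
            self.assertPandasEqual(expected2.toPandas(), result2.toPandas())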

