Github user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18732#discussion_r143740129
  
    --- Diff: python/pyspark/sql/tests.py ---
    @@ -3376,6 +3376,151 @@ def test_vectorized_udf_empty_partition(self):
             res = df.select(f(col('id')))
             self.assertEquals(df.collect(), res.collect())
     
    +    def test_vectorized_udf_varargs(self):
    +        from pyspark.sql.functions import pandas_udf, col
    +        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
    +        f = pandas_udf(lambda *v: v[0], LongType())
    +        res = df.select(f(col('id')))
    +        self.assertEquals(df.collect(), res.collect())
    +
    +
    +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
    +class GroupbyApplyTests(ReusedPySparkTestCase):
    +    @classmethod
    +    def setUpClass(cls):
    +        ReusedPySparkTestCase.setUpClass()
    +        cls.spark = SparkSession(cls.sc)
    +
    +    @classmethod
    +    def tearDownClass(cls):
    +        ReusedPySparkTestCase.tearDownClass()
    +        cls.spark.stop()
    +
    +    def assertFramesEqual(self, expected, result):
    +        msg = ("DataFrames are not equal: " +
    +               ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) +
    +               ("\n\nResult:\n%s\n%s" % (result, result.dtypes)))
    +        self.assertTrue(expected.equals(result), msg=msg)
    +
    +    @property
    +    def data(self):
    +        from pyspark.sql.functions import array, explode, col, lit
    +        return self.spark.range(10).toDF('id') \
    +            .withColumn("vs", array([lit(i) for i in range(20, 30)])) \
    +            .withColumn("v", explode(col('vs'))).drop('vs')
    +
    +    def test_simple(self):
    +        from pyspark.sql.functions import pandas_udf
    +        df = self.data
    +
    +        foo_udf = pandas_udf(
    +            lambda df: df.assign(v1=df.v * df.id * 1.0, v2=df.v + df.id),
    +            StructType(
    +                [StructField('id', LongType()),
    +                 StructField('v', IntegerType()),
    +                 StructField('v1', DoubleType()),
    +                 StructField('v2', LongType())]))
    +
    +        result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
    +        expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
    +        self.assertFramesEqual(expected, result)
    +
    +    def test_decorator(self):
    +        from pyspark.sql.functions import pandas_udf
    +        df = self.data
    +
    +        @pandas_udf(StructType(
    +            [StructField('id', LongType()),
    +             StructField('v', IntegerType()),
    +             StructField('v1', DoubleType()),
    +             StructField('v2', LongType())]))
    +        def foo(df):
    +            return df.assign(v1=df.v * df.id * 1.0, v2=df.v + df.id)
    +
    +        result = df.groupby('id').apply(foo).sort('id').toPandas()
    +        expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
    +        self.assertFramesEqual(expected, result)
    +
    +    def test_coerce(self):
    +        from pyspark.sql.functions import pandas_udf
    +        df = self.data
    +
    +        foo = pandas_udf(
    +            lambda df: df,
    --- End diff ---
    
    Fixed.
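    For context, a minimal, hypothetical sketch of the groupby().apply() pattern these tests exercise; it assumes a local SparkSession with pandas and pyarrow installed and simply mirrors the API used in the diff above.
    
        from pyspark.sql import SparkSession
        from pyspark.sql.functions import pandas_udf, col
        from pyspark.sql.types import StructType, StructField, LongType, DoubleType
    
        # Assumption: a local session; in the tests above it comes from ReusedPySparkTestCase.
        spark = SparkSession.builder.master("local[2]").getOrCreate()
    
        # A group key `id` and a value column `v`.
        df = spark.range(10).withColumn("v", col("id") * 1.0)
    
        # The UDF receives one pandas.DataFrame per group and returns a
        # pandas.DataFrame matching the declared StructType, as in test_decorator above.
        @pandas_udf(StructType([StructField("id", LongType()),
                                StructField("v", DoubleType()),
                                StructField("v2", DoubleType())]))
        def add_v2(pdf):
            return pdf.assign(v2=pdf.v + pdf.id)
    
        df.groupby("id").apply(add_v2).show()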

