Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/19630#discussion_r151523879 --- Diff: python/pyspark/sql/tests.py --- @@ -3166,6 +3166,92 @@ def test_filtered_frame(self): self.assertTrue(pdf.empty) +class PandasUDFTests(ReusedSQLTestCase): + def test_pandas_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + udf = pandas_udf(lambda x: x, DoubleType()) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, 'v double', + functionType=PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + udf = pandas_udf(lambda x: x, returnType='v double', + functionType=PandasUDFType.GROUP_MAP) + self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) + self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + def test_pandas_udf_decorator(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql.types import StructType, StructField, DoubleType + + @pandas_udf(DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + @pandas_udf(returnType=DoubleType()) + def foo(x): + return x + self.assertEqual(foo.returnType, 
DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + + schema = StructType([StructField("v", DoubleType())]) + + @pandas_udf(schema, PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf(schema, functionType=PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUP_MAP) + def foo(x): + return x + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + + def test_udf_wrong_arg(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'return type'): + @pandas_udf(PandasUDFType.GROUP_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): + @pandas_udf(returnType=PandasUDFType.GROUP_MAP) + def foo(df): + return df + with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): + @pandas_udf(returnType='double', functionType=PandasUDFType.GROUP_MAP) --- End diff -- oh i see
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org