Github user HyukjinKwon commented on a diff in the pull request: https://github.com/apache/spark/pull/22568#discussion_r220944429 --- Diff: python/pyspark/sql/tests.py --- @@ -5525,32 +5525,73 @@ def data(self): .withColumn("v", explode(col('vs'))).drop('vs') def test_supported_types(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col - df = self.data.withColumn("arr", array(col("id"))) + from decimal import Decimal + from distutils.version import LooseVersion + import pyarrow as pa + from pyspark.sql.functions import pandas_udf, PandasUDFType - # Different forms of group map pandas UDF, results of these are the same + input_values_with_schema = [ + (1, StructField('id', IntegerType())), + (2, StructField('byte', ByteType())), + (3, StructField('short', ShortType())), + (4, StructField('int', IntegerType())), + (5, StructField('long', LongType())), + (1.1, StructField('float', FloatType())), + (2.2, StructField('double', DoubleType())), + (Decimal(1.123), StructField('decim', DecimalType(10, 3))), + ([1, 2, 3], StructField('array', ArrayType(IntegerType()))), + (True, StructField('bool', BooleanType())), + ('hello', StructField('str', StringType())), + ] - output_schema = StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('arr', ArrayType(LongType())), - StructField('v1', DoubleType()), - StructField('v2', LongType())]) + # TODO: Add BinaryType to 'input_values_with_schema' once minimum pyarrow version is 0.10.0 + if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"): + input_values_with_schema.append( + (bytearray([0x01, 0x02]), StructField('bin', BinaryType())) + ) + + values = [[x[0] for x in input_values_with_schema]] + output_schema = StructType([x[1] for x in input_values_with_schema]) + df = self.spark.createDataFrame(values, schema=output_schema) + + # Different forms of group map pandas UDF, results of these are the same udf1 = pandas_udf( - lambda pdf: pdf.assign(v1=pdf.v * 
pdf.id * 1.0, v2=pdf.v + pdf.id), + lambda pdf: pdf.assign( + decim=pdf.decim + pdf.decim, + double=pdf.double + pdf.float, + byte=pdf.byte + 1, + long=pdf.byte + pdf.int + pdf.long + pdf.short, --- End diff -- I would get rid of these calculations that mix different types; without them the test would be easier to read.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org