Dean Wyatte created SPARK-29990:
-----------------------------------

Summary: Pandas UDF including numpy sort operation not working with multi-dimensional array column
Key: SPARK-29990
URL: https://issues.apache.org/jira/browse/SPARK-29990
Project: Spark
Issue Type: Bug
Components: PySpark
Affects Versions: 2.4.3
Reporter: Dean Wyatte
If I have a multi-dimensional array column in a Spark DataFrame (e.g., ArrayType(ArrayType(LongType()))), I am unable to apply a Pandas UDF that uses a numpy sort operation. The function works on a native pandas Series, but when called as a UDF it fails with a numpy error:

"ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"

Minimal code to reproduce:

{code:python}
import random

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, LongType
from pyspark.sql.functions import pandas_udf


def ndsort(s):
    # Sort the entries of each row along axis=1.
    return pd.Series([x for x in np.sort(s.values.tolist(), axis=1)])


# 100 rows, each a shuffled list of 10 [i, 400000 + i] pairs.
X = [random.sample([[i, 400000 + i] for i in range(10)], 10) for j in range(100)]
pdf = pd.DataFrame([[x] for x in X], columns=['x'])

spark = SparkSession.builder.appName('test').getOrCreate()
sdf = spark.createDataFrame(pdf)
ndsort_udf = pandas_udf(ndsort, ArrayType(ArrayType(LongType())))

ndsort(pdf['x'])  # works
sdf.withColumn('y', ndsort_udf('x')).show()  # errors
{code}

UDFs that use a numpy sort operation work fine if the column is only ArrayType(LongType()):

{code:python}
# 100 rows, each a shuffled list of 10 longs.
X = [random.sample([400000 + i for i in range(10)], 10) for j in range(100)]
pdf = pd.DataFrame([[x] for x in X], columns=['x'])
sdf = spark.createDataFrame(pdf)
ndsort_udf = pandas_udf(ndsort, ArrayType(LongType()))

ndsort(pdf['x'])  # works
sdf.withColumn('y', ndsort_udf('x')).show()  # also works
{code}
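A possible explanation, offered as an assumption and not verified against Spark 2.4.3: inside the UDF, pyarrow's to_pandas() is assumed to deliver each array<array<bigint>> row as a 1-D object ndarray whose elements are int64 ndarrays, rather than as the nested Python lists that the native pandas DataFrame holds. In that case s.values.tolist() yields a list of object arrays, np.sort builds an (n, 10) object array whose elements are whole ndarrays, and comparing two of them with `<` raises the ambiguous-truth-value error. The sketch below mimics that layout with plain numpy and pandas (no Spark). as_object_rows and ndsort_stacked are hypothetical helpers, and densifying each row with np.stack is one possible workaround, not a confirmed fix.

```python
import numpy as np
import pandas as pd


def as_object_rows(rows):
    """Hypothetical helper: build a Series in the layout assumed for Arrow
    input, where each row is a 1-D object ndarray of int64 ndarrays."""
    outer = np.empty(len(rows), dtype=object)
    for r, row in enumerate(rows):
        inner = np.empty(len(row), dtype=object)
        for i, pair in enumerate(row):
            inner[i] = np.asarray(pair, dtype=np.int64)  # store array as one object
        outer[r] = inner
    return pd.Series(outer)


def ndsort(s):
    # The reporter's function, unchanged.
    return pd.Series([x for x in np.sort(s.values.tolist(), axis=1)])


def ndsort_stacked(s):
    """Hypothetical workaround: densify every row with np.stack so np.sort
    compares int64 scalars instead of whole ndarrays. Assumes every row has
    the same number of equal-length pairs."""
    dense = np.stack([np.stack(list(row)) for row in s])  # shape (n, k, 2)
    sorted_rows = np.sort(dense, axis=1)
    out = np.empty(len(sorted_rows), dtype=object)
    for i, row in enumerate(sorted_rows):
        out[i] = row.tolist()  # plain nested lists of Python ints
    return pd.Series(out)
```

On the mimicked Series, ndsort raises the same ValueError while ndsort_stacked returns each row sorted. Returning plain nested lists instead of 2-D ndarrays is a choice made here on the assumption that Arrow serializes them more readily for an ArrayType(ArrayType(LongType())) result; that has not been checked against Spark.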