Michael Tong created SPARK-27691:
------------------------------------

             Summary: Issue when running queries using filter predicates on pandas GROUPED_AGG udfs
                 Key: SPARK-27691
                 URL: https://issues.apache.org/jira/browse/SPARK-27691
             Project: Spark
          Issue Type: Bug
          Components: Input/Output
    Affects Versions: 2.4.2
            Reporter: Michael Tong

I am running pyspark 2.4.2 and am unable to run the following code.

{code:python}
from pyspark.sql import SparkSession, functions, types
import pandas as pd
import random

# reuses the existing session when run inside the pyspark shell
spark = SparkSession.builder.getOrCreate()

# initialize test data
test_data = [[False, int(random.random() * 2)] for i in range(10000)]
test_data = pd.DataFrame(test_data, columns=['bool_value', 'int_value'])

# pandas udf
pandas_any_udf = functions.pandas_udf(lambda x: x.any(), types.BooleanType(), functions.PandasUDFType.GROUPED_AGG)

# create spark DataFrame and build the query
test_df = spark.createDataFrame(test_data)
test_df = test_df.groupby('int_value').agg(pandas_any_udf('bool_value').alias('bool_any_result'))
test_df = test_df.filter(functions.col('bool_any_result') == True)

# write to output
test_df.write.parquet('/tmp/mtong/write_test', mode='overwrite')
{code}

Below is a truncated error message.

{code:java}
Py4JJavaError: An error occurred while calling o1125.parquet.
: org.apache.spark.SparkException: Job aborted.
...
Exchange hashpartitioning(int_value#123L, 2000)
+- *(1) Filter (<lambda>(bool_value#122) = true)
   +- Scan ExistingRDD arrow[bool_value#122,int_value#123L]
...
Caused by: java.lang.UnsupportedOperationException: Cannot evaluate expression: <lambda>(input[0, boolean, true])
{code}

What appears to be happening is that the query optimizer incorrectly pushes the filter predicate on bool_any_result down below the group-by operation. In the plan above, the Filter node sits directly on top of the scan and contains the pandas udf itself (<lambda>(bool_value#122) = true), so the udf ends up inside a row-level filter where it cannot be evaluated, and the job aborts.

I have also tried running a variant of this query with functions.count() as the aggregation function and the query ran fine, so I believe that this is an error that only affects pandas udfs.

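The behavior expected from the failing query can be sketched without Spark. The plain-Python snippet below only illustrates the intended order of operations (group, aggregate, then filter on the aggregated value). `group_any_then_filter` is a hypothetical helper introduced for this illustration and is not part of pyspark or pandas.

```python
from collections import defaultdict


def group_any_then_filter(rows):
    """Aggregate first, filter second (hypothetical helper, not a pyspark API).

    rows: iterable of (bool_value, int_value) pairs, mirroring the test data.
    Returns {int_value: True} for every group whose any() result is True.
    """
    # Step 1: group by int_value and collect the bool_value entries per group.
    groups = defaultdict(list)
    for bool_value, int_value in rows:
        groups[int_value].append(bool_value)

    # Step 2: compute the per-group aggregate (the role of the pandas udf).
    bool_any_result = {key: any(values) for key, values in groups.items()}

    # Step 3: only now apply the predicate, on the aggregated value.
    return {key: result for key, result in bool_any_result.items() if result}


# All-False input, as in the report: every group aggregates to False,
# so the filtered result is empty rather than an error.
all_false = [(False, i % 2) for i in range(10)]
print(group_any_then_filter(all_false))  # {}

# A single True value in group 1 makes that group survive the filter.
mixed = all_false + [(True, 1)]
print(group_any_then_filter(mixed))  # {1: True}
```

Since the test data in the report contains only False values, the correct result of the Spark query would be an empty output, not an exception.
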
Variant of query with standard aggregation function

{code:python}
test_df = spark.createDataFrame(test_data)
test_df = test_df.groupby('int_value').agg(functions.count('bool_value').alias('bool_counts'))
test_df = test_df.filter(functions.col('bool_counts') > 0)
{code}