[ https://issues.apache.org/jira/browse/SPARK-43298?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Lee Yang updated SPARK-43298: ----------------------------- Description: This is related to SPARK-42250. For scalar inputs, the predict_batch_udf will fail if the batch size is 1: {code:java} import numpy as np from pyspark.ml.functions import predict_batch_udf from pyspark.sql.types import DoubleType df = spark.createDataFrame([[1.0],[2.0]], schema=["a"]) def make_predict_fn(): def predict(inputs): return inputs return predict identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), batch_size=1) preds = df.withColumn("preds", identity("a")).collect() {code} fails with: {code:java} File "/.../spark/python/pyspark/worker.py", line 869, in main process() File "/.../spark/python/pyspark/worker.py", line 861, in process serializer.dump_stream(out_iter, outfile) File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 354, in dump_stream return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 86, in dump_stream for batch in iterator: File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 347, in init_stream_yield_batches for series in iterator: File "/.../spark/python/pyspark/worker.py", line 555, in func for result_batch, result_type in result_iter: File "/.../spark/python/pyspark/ml/functions.py", line 818, in predict yield _validate_and_transform_prediction_result( File "/.../spark/python/pyspark/ml/functions.py", line 339, in _validate_and_transform_prediction_result if len(preds_array) != num_input_rows: TypeError: len() of unsized object {code} was: This is related to SPARK-42250. For scalar inputs, the predict_batch_udf will fail if the batch size is 1: {code} import numpy as np from pyspark.ml.functions import predict_batch_udf from pyspark.sql.types import DoubleType df = spark.createDataFrame([[1.0],[2.0]], schema=["a"]) def make_predict_fn(): def predict(inputs): return inputs return predict identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), batch_size=1) preds = df.withColumn("preds", identity("a")).collect() {code} fails with: {code} File "/home/leey/devpub/spark/python/pyspark/worker.py", line 869, in main process() File "/home/leey/devpub/spark/python/pyspark/worker.py", line 861, in process serializer.dump_stream(out_iter, outfile) File "/home/leey/devpub/spark/python/pyspark/sql/pandas/serializers.py", line 354, in dump_stream return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) File "/home/leey/devpub/spark/python/pyspark/sql/pandas/serializers.py", line 86, in dump_stream for batch in iterator: File "/home/leey/devpub/spark/python/pyspark/sql/pandas/serializers.py", line 347, in init_stream_yield_batches for series in iterator: File "/home/leey/devpub/spark/python/pyspark/worker.py", line 555, in func for result_batch, result_type in result_iter: File "/home/leey/devpub/spark/python/pyspark/ml/functions.py", line 818, in predict yield _validate_and_transform_prediction_result( File "/home/leey/devpub/spark/python/pyspark/ml/functions.py", line 339, in _validate_and_transform_prediction_result if len(preds_array) != num_input_rows: TypeError: len() of unsized object {code} > predict_batch_udf with scalar input fails when batch size consists of a > single value > ------------------------------------------------------------------------------------ > > Key: SPARK-43298 > URL: https://issues.apache.org/jira/browse/SPARK-43298 > Project: Spark > Issue Type: Bug > Components: ML, PySpark > Affects Versions: 3.4.0 > Reporter: Lee Yang > Priority: Major > > This is related to SPARK-42250. For scalar inputs, the predict_batch_udf > will fail if the batch size is 1: > {code:java} > import numpy as np > from pyspark.ml.functions import predict_batch_udf > from pyspark.sql.types import DoubleType > df = spark.createDataFrame([[1.0],[2.0]], schema=["a"]) > def make_predict_fn(): > def predict(inputs): > return inputs > return predict > identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), > batch_size=1) > preds = df.withColumn("preds", identity("a")).collect() > {code} > fails with: > {code:java} > File "/.../spark/python/pyspark/worker.py", line 869, in main > process() > File "/.../spark/python/pyspark/worker.py", line 861, in process > serializer.dump_stream(out_iter, outfile) > File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 354, in > dump_stream > return ArrowStreamSerializer.dump_stream(self, > init_stream_yield_batches(), stream) > File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 86, in > dump_stream > for batch in iterator: > File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 347, in > init_stream_yield_batches > for series in iterator: > File "/.../spark/python/pyspark/worker.py", line 555, in func > for result_batch, result_type in result_iter: > File "/.../spark/python/pyspark/ml/functions.py", line 818, in predict > yield _validate_and_transform_prediction_result( > File "/.../spark/python/pyspark/ml/functions.py", line 339, in > _validate_and_transform_prediction_result > if len(preds_array) != num_input_rows: > TypeError: len() of unsized object > {code} -- This message was sent by Atlassian Jira (v8.20.10#820010) --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org For additional commands, e-mail: issues-h...@spark.apache.org