Hyukjin Kwon created SPARK-42250:
------------------------------------

             Summary: batch_infer_udf (predict_batch_udf) with a float return type fails when a batch consists of a single value
                 Key: SPARK-42250
                 URL: https://issues.apache.org/jira/browse/SPARK-42250
             Project: Spark
          Issue Type: Bug
          Components: ML, PySpark
    Affects Versions: 3.4.0
            Reporter: Hyukjin Kwon


{code}
import numpy as np
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import ArrayType, FloatType, StructType, StructField
from typing import Mapping

# `spark` is not created in this snippet; it is assumed to be an existing
# SparkSession (e.g. the one the pyspark shell provides).
# Two rows: column t1 holds 4-element float lists, column t2 3-element ones.
df = spark.createDataFrame(
    [
        [[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0]],
        [[4.0, 5.0, 6.0, 7.0], [4.0, 5.0, 6.0]],
    ],
    schema=["t1", "t2"],
)

def make_multi_sum_fn():
    """Build the ``predict`` callable handed to ``predict_batch_udf``.

    The returned function sums each of its two inputs along axis 1 and
    adds the two results element-wise.
    """

    def predict(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
        # Reduce each input along axis 1, then combine the partial sums.
        totals_1 = np.sum(x1, axis=1)
        totals_2 = np.sum(x2, axis=1)
        return totals_1 + totals_2

    return predict

# Wrap the model factory as a batch-inference UDF returning FloatType.
# input_tensor_shapes presumably declares the per-row shapes of t1 ([4]) and
# t2 ([3]) -- confirm against the predict_batch_udf docs.
# batch_size=1 presumably hands `predict` a single row per batch, which is the
# single-value-batch case the summary of this issue describes as failing.
multi_sum_udf = predict_batch_udf(
    make_multi_sum_fn,
    return_type=FloatType(),
    batch_size=1,
    input_tensor_shapes=[[4], [3]],
)

# Running the UDF triggers the ArrowInvalid error shown in this report.
df.select(multi_sum_udf("t1", "t2")).collect()
{code}

This fails with the following error:

{code}
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 829, in main
    process()
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 821, in process
    serializer.dump_stream(out_iter, outfile)
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 345, in dump_stream
    return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 86, in dump_stream
    for batch in iterator:
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 339, in init_stream_yield_batches
    batch = self._create_batch(series)
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 275, in _create_batch
    arrs.append(create_array(s, t))
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 245, in create_array
    raise e
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 233, in create_array
    array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
  File "pyarrow/array.pxi", line 1044, in pyarrow.lib.Array.from_pandas
  File "pyarrow/array.pxi", line 316, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 83, in pyarrow.lib._ndarray_to_array
  File "pyarrow/error.pxi", line 100, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Could not convert array(569.) with type numpy.ndarray: tried to convert to float32

        at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:554)
        at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:118)
        at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:507)
        at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
        at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
        at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
        at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
        at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:391)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:888)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:888)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:92)
        at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
        at org.apache.spark.scheduler.Task.run(Task.scala:139)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:554)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1520)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:557)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
        at java.lang.Thread.run(Thread.java:748)

{code}



--
This message was sent by Atlassian Jira
(v8.20.10#820010)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org

Reply via email to