[ https://issues.apache.org/jira/browse/SPARK-49793?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Xinrong Meng updated SPARK-49793: --------------------------------- Description: {code:java} import numpy as np import pandas as pd from pyspark.ml.functions import predict_batch_udf from pyspark.sql.types import DoubleType from pyspark.sql.functions import struct data = np.arange(0, 36, dtype=np.float64).reshape(-1, 4) pdf = pd.DataFrame(data, columns=["a", "b", "c", "d"]) df = spark.createDataFrame(pdf) def make_predict_fn(): fake_output = np.random.random() def predict(inputs): return np.array([fake_output for i in inputs]) return predict identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), batch_size=5) df1 = df.withColumn("preds", identity(struct("a"))).toPandas() df2 = df.withColumn("preds", identity(struct("a"))).toPandas() {code} NumPy 2.1.0 {code:java} >>> df1 a b c d preds 0 0.0 1.0 2.0 3.0 0.431752 1 4.0 5.0 6.0 7.0 0.912097 2 8.0 9.0 10.0 11.0 0.679628 3 12.0 13.0 14.0 15.0 0.853850 4 16.0 17.0 18.0 19.0 0.389971 5 20.0 21.0 22.0 23.0 0.654521 6 24.0 25.0 26.0 27.0 0.430569 7 28.0 29.0 30.0 31.0 0.331055 8 32.0 33.0 34.0 35.0 0.306073 >>> df2 a b c d preds 0 0.0 1.0 2.0 3.0 0.679628 1 4.0 5.0 6.0 7.0 0.430569 2 8.0 9.0 10.0 11.0 0.853850 3 12.0 13.0 14.0 15.0 0.306073 4 16.0 17.0 18.0 19.0 0.654521 5 20.0 21.0 22.0 23.0 0.389971 6 24.0 25.0 26.0 27.0 0.507598 7 28.0 29.0 30.0 31.0 0.912097 8 32.0 33.0 34.0 35.0 0.431752 {code} which should be {code:java} >>> df1 a b c d preds 0 0.0 1.0 2.0 3.0 0.685941 1 4.0 5.0 6.0 7.0 0.685941 2 8.0 9.0 10.0 11.0 0.685941 3 12.0 13.0 14.0 15.0 0.685941 4 16.0 17.0 18.0 19.0 0.685941 5 20.0 21.0 22.0 23.0 0.685941 6 24.0 25.0 26.0 27.0 0.685941 7 28.0 29.0 30.0 31.0 0.685941 8 32.0 33.0 34.0 35.0 0.685941 >>> df2 a b c d preds 0 0.0 1.0 2.0 3.0 0.685941 1 4.0 5.0 6.0 7.0 0.685941 2 8.0 9.0 10.0 11.0 0.685941 3 12.0 13.0 14.0 15.0 0.685941 4 16.0 17.0 18.0 19.0 0.685941 5 20.0 21.0 22.0 23.0 0.685941 6 24.0 25.0 26.0 27.0 0.685941 7 28.0 29.0 30.0 31.0 0.685941 8 32.0 33.0 34.0 35.0 0.685941 {code} > 
Enable PredictBatchUDFTests.test_caching for NumPy 2 > ---------------------------------------------------- > > Key: SPARK-49793 > URL: https://issues.apache.org/jira/browse/SPARK-49793 > Project: Spark > Issue Type: Story > Components: ML, Tests > Affects Versions: 4.0.0 > Reporter: Xinrong Meng > Priority: Major > > > {code:java} > import numpy as np > import pandas as pd > from pyspark.ml.functions import predict_batch_udf > from pyspark.sql.types import DoubleType > from pyspark.sql.functions import struct > data = np.arange(0, 36, dtype=np.float64).reshape(-1, 4) > pdf = pd.DataFrame(data, columns=["a", "b", "c", "d"]) > df = spark.createDataFrame(pdf) > def make_predict_fn(): > fake_output = np.random.random() > def predict(inputs): > return np.array([fake_output for i in inputs]) > return predict > > identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), > batch_size=5) > df1 = df.withColumn("preds", identity(struct("a"))).toPandas() > df2 = df.withColumn("preds", identity(struct("a"))).toPandas() > {code} > NumPy 2.1.0 > {code:java} > >>> df1 > a b c d preds > 0 0.0 1.0 2.0 3.0 0.431752 > 1 4.0 5.0 6.0 7.0 0.912097 > 2 8.0 9.0 10.0 11.0 0.679628 > 3 12.0 13.0 14.0 15.0 0.853850 > 4 16.0 17.0 18.0 19.0 0.389971 > 5 20.0 21.0 22.0 23.0 0.654521 > 6 24.0 25.0 26.0 27.0 0.430569 > 7 28.0 29.0 30.0 31.0 0.331055 > 8 32.0 33.0 34.0 35.0 0.306073 > >>> df2 > a b c d preds > 0 0.0 1.0 2.0 3.0 0.679628 > 1 4.0 5.0 6.0 7.0 0.430569 > 2 8.0 9.0 10.0 11.0 0.853850 > 3 12.0 13.0 14.0 15.0 0.306073 > 4 16.0 17.0 18.0 19.0 0.654521 > 5 20.0 21.0 22.0 23.0 0.389971 > 6 24.0 25.0 26.0 27.0 0.507598 > 7 28.0 29.0 30.0 31.0 0.912097 > 8 32.0 33.0 34.0 35.0 0.431752 {code} > which should be > {code:java} > >>> df1 > a b c d preds > 0 0.0 1.0 2.0 3.0 0.685941 > 1 4.0 5.0 6.0 7.0 0.685941 > 2 8.0 9.0 10.0 11.0 0.685941 > 3 12.0 13.0 14.0 15.0 0.685941 > 4 16.0 17.0 18.0 19.0 0.685941 > 5 20.0 21.0 22.0 23.0 0.685941 > 6 24.0 25.0 26.0 27.0 0.685941 > 7 28.0 29.0 
30.0 31.0 0.685941 > 8 32.0 33.0 34.0 35.0 0.685941 > >>> df2 > a b c d preds > 0 0.0 1.0 2.0 3.0 0.685941 > 1 4.0 5.0 6.0 7.0 0.685941 > 2 8.0 9.0 10.0 11.0 0.685941 > 3 12.0 13.0 14.0 15.0 0.685941 > 4 16.0 17.0 18.0 19.0 0.685941 > 5 20.0 21.0 22.0 23.0 0.685941 > 6 24.0 25.0 26.0 27.0 0.685941 > 7 28.0 29.0 30.0 31.0 0.685941 > 8 32.0 33.0 34.0 35.0 0.685941 {code} > -- This message was sent by Atlassian Jira (v8.20.10#820010) --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscribe@spark.apache.org For additional commands, e-mail: issues-help@spark.apache.org