[ https://issues.apache.org/jira/browse/SPARK-49793?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Weichen Xu reassigned SPARK-49793:
----------------------------------

    Assignee: Weichen Xu

> Enable PredictBatchUDFTests.test_caching for NumPy 2
> ----------------------------------------------------
>
>                 Key: SPARK-49793
>                 URL: https://issues.apache.org/jira/browse/SPARK-49793
>             Project: Spark
>          Issue Type: Story
>          Components: ML, Tests
>    Affects Versions: 4.0.0
>            Reporter: Xinrong Meng
>            Assignee: Weichen Xu
>            Priority: Major
>
>
> {code:java}
> import numpy as np
> import pandas as pd
> from pyspark.ml.functions import predict_batch_udf
> from pyspark.sql.types import DoubleType
> from pyspark.sql.functions import struct
>
> data = np.arange(0, 36, dtype=np.float64).reshape(-1, 4)
> pdf = pd.DataFrame(data, columns=["a", "b", "c", "d"])
> df = spark.createDataFrame(pdf)
>
> def make_predict_fn():
>     fake_output = np.random.random()
>     def predict(inputs):
>         return np.array([fake_output for i in inputs])
>     return predict
>
> identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(),
>                              batch_size=5)
> df1 = df.withColumn("preds", identity(struct("a"))).toPandas()
> df2 = df.withColumn("preds", identity(struct("a"))).toPandas()
> {code}
>
> With NumPy 2.1.0 the output is:
> {code:java}
> >>> df1
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.431752
> 1   4.0   5.0   6.0   7.0  0.912097
> 2   8.0   9.0  10.0  11.0  0.679628
> 3  12.0  13.0  14.0  15.0  0.853850
> 4  16.0  17.0  18.0  19.0  0.389971
> 5  20.0  21.0  22.0  23.0  0.654521
> 6  24.0  25.0  26.0  27.0  0.430569
> 7  28.0  29.0  30.0  31.0  0.331055
> 8  32.0  33.0  34.0  35.0  0.306073
> >>> df2
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.679628
> 1   4.0   5.0   6.0   7.0  0.430569
> 2   8.0   9.0  10.0  11.0  0.853850
> 3  12.0  13.0  14.0  15.0  0.306073
> 4  16.0  17.0  18.0  19.0  0.654521
> 5  20.0  21.0  22.0  23.0  0.389971
> 6  24.0  25.0  26.0  27.0  0.507598
> 7  28.0  29.0  30.0  31.0  0.912097
> 8  32.0  33.0  34.0  35.0  0.431752
> {code}
>
> whereas it should be:
> {code:java}
> >>> df1
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.685941
> 1   4.0   5.0   6.0   7.0  0.685941
> 2   8.0   9.0  10.0  11.0  0.685941
> 3  12.0  13.0  14.0  15.0  0.685941
> 4  16.0  17.0  18.0  19.0  0.685941
> 5  20.0  21.0  22.0  23.0  0.685941
> 6  24.0  25.0  26.0  27.0  0.685941
> 7  28.0  29.0  30.0  31.0  0.685941
> 8  32.0  33.0  34.0  35.0  0.685941
> >>> df2
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.685941
> 1   4.0   5.0   6.0   7.0  0.685941
> 2   8.0   9.0  10.0  11.0  0.685941
> 3  12.0  13.0  14.0  15.0  0.685941
> 4  16.0  17.0  18.0  19.0  0.685941
> 5  20.0  21.0  22.0  23.0  0.685941
> 6  24.0  25.0  26.0  27.0  0.685941
> 7  28.0  29.0  30.0  31.0  0.685941
> 8  32.0  33.0  34.0  35.0  0.685941
> {code}
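In the expected output for this issue, every row of both `df1` and `df2` gets one identical prediction. That is the result you would see if the predict function returned by `make_predict_fn` were built once and then reused for every batch and every query. The stdlib-only sketch below mimics that caching contract under stated assumptions. It is an illustration, not PySpark's internal implementation. The names `get_predict_fn`, `run_batches`, the `"model-0"` key and the module-level cache dict are all hypothetical, and `random.random()` stands in for `np.random.random()`.

```python
import random
from typing import Callable, Dict, List, Sequence

PredictFn = Callable[[Sequence[float]], List[float]]

# Hypothetical per-process cache keyed by a model id (illustration only,
# not Spark's actual internals).
_PREDICT_FN_CACHE: Dict[str, PredictFn] = {}
make_calls = 0  # counts how many times the "model" is built


def make_predict_fn() -> PredictFn:
    global make_calls
    make_calls += 1
    fake_output = random.random()  # drawn once per make_predict_fn() call

    def predict(inputs: Sequence[float]) -> List[float]:
        # Same constant for every input in the batch, as in the repro.
        return [fake_output for _ in inputs]

    return predict


def get_predict_fn(model_id: str) -> PredictFn:
    # Build the predict function on first use, then reuse the cached one.
    if model_id not in _PREDICT_FN_CACHE:
        _PREDICT_FN_CACHE[model_id] = make_predict_fn()
    return _PREDICT_FN_CACHE[model_id]


def run_batches(rows: Sequence[float], batch_size: int,
                model_id: str = "model-0") -> List[float]:
    # Split rows into fixed-size batches and predict each batch.
    preds: List[float] = []
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        preds.extend(get_predict_fn(model_id)(batch))
    return preds


# Column "a" of the repro: 0.0, 4.0, ..., 32.0 (9 rows).
rows = [float(x) for x in range(0, 36, 4)]
preds1 = run_batches(rows, batch_size=5)  # analogous to df1
preds2 = run_batches(rows, batch_size=5)  # analogous to df2
```

Suppose you called `make_predict_fn()` directly inside the batch loop instead of going through the cache. Each batch would then draw a fresh value, and `preds1` would generally no longer be uniform. That breaks the property the expected output shows.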