[ 
https://issues.apache.org/jira/browse/SPARK-49793?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
 ]

Weichen Xu reassigned SPARK-49793:
----------------------------------

    Assignee: Weichen Xu

> Enable PredictBatchUDFTests.test_caching for NumPy 2
> ----------------------------------------------------
>
>                 Key: SPARK-49793
>                 URL: https://issues.apache.org/jira/browse/SPARK-49793
>             Project: Spark
>          Issue Type: Story
>          Components: ML, Tests
>    Affects Versions: 4.0.0
>            Reporter: Xinrong Meng
>            Assignee: Weichen Xu
>            Priority: Major
>
>  
> {code:python}
> import numpy as np
> import pandas as pd
> from pyspark.ml.functions import predict_batch_udf
> from pyspark.sql.types import DoubleType
> from pyspark.sql.functions import struct
> data = np.arange(0, 36, dtype=np.float64).reshape(-1, 4)
> pdf = pd.DataFrame(data, columns=["a", "b", "c", "d"])
> df = spark.createDataFrame(pdf)
> def make_predict_fn():
>     fake_output = np.random.random()
>     def predict(inputs):
>         return np.array([fake_output for i in inputs])
>     return predict
>  
> identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), batch_size=5)
> df1 = df.withColumn("preds", identity(struct("a"))).toPandas()
> df2 = df.withColumn("preds", identity(struct("a"))).toPandas()
> {code}
> With NumPy 2.1.0, the preds column varies across rows and between df1 and df2:
> {code:java}
> >>> df1
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.431752
> 1   4.0   5.0   6.0   7.0  0.912097
> 2   8.0   9.0  10.0  11.0  0.679628
> 3  12.0  13.0  14.0  15.0  0.853850
> 4  16.0  17.0  18.0  19.0  0.389971
> 5  20.0  21.0  22.0  23.0  0.654521
> 6  24.0  25.0  26.0  27.0  0.430569
> 7  28.0  29.0  30.0  31.0  0.331055
> 8  32.0  33.0  34.0  35.0  0.306073
> >>> df2
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.679628
> 1   4.0   5.0   6.0   7.0  0.430569
> 2   8.0   9.0  10.0  11.0  0.853850
> 3  12.0  13.0  14.0  15.0  0.306073
> 4  16.0  17.0  18.0  19.0  0.654521
> 5  20.0  21.0  22.0  23.0  0.389971
> 6  24.0  25.0  26.0  27.0  0.507598
> 7  28.0  29.0  30.0  31.0  0.912097
> 8  32.0  33.0  34.0  35.0  0.431752
> {code}
> The expected output has one identical preds value in every row of both df1 and df2:
> {code:java}
> >>> df1
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.685941
> 1   4.0   5.0   6.0   7.0  0.685941
> 2   8.0   9.0  10.0  11.0  0.685941
> 3  12.0  13.0  14.0  15.0  0.685941
> 4  16.0  17.0  18.0  19.0  0.685941
> 5  20.0  21.0  22.0  23.0  0.685941
> 6  24.0  25.0  26.0  27.0  0.685941
> 7  28.0  29.0  30.0  31.0  0.685941
> 8  32.0  33.0  34.0  35.0  0.685941
> >>> df2
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.685941
> 1   4.0   5.0   6.0   7.0  0.685941
> 2   8.0   9.0  10.0  11.0  0.685941
> 3  12.0  13.0  14.0  15.0  0.685941
> 4  16.0  17.0  18.0  19.0  0.685941
> 5  20.0  21.0  22.0  23.0  0.685941
> 6  24.0  25.0  26.0  27.0  0.685941
> 7  28.0  29.0  30.0  31.0  0.685941
> 8  32.0  33.0  34.0  35.0  0.685941
> {code}
>  



--
This message was sent by Atlassian Jira
(v8.20.10#820010)
