[https://issues.apache.org/jira/browse/SPARK-49793?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel]

Xinrong Meng updated SPARK-49793:
---------------------------------
    Description: 
 
{code:java}
import numpy as np
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import struct
data = np.arange(0, 36, dtype=np.float64).reshape(-1, 4)
pdf = pd.DataFrame(data, columns=["a", "b", "c", "d"])
df = spark.createDataFrame(pdf)
def make_predict_fn():
    fake_output = np.random.random()
    def predict(inputs):
        return np.array([fake_output for i in inputs])
    return predict
 
identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), 
batch_size=5)
df1 = df.withColumn("preds", identity(struct("a"))).toPandas()
df2 = df.withColumn("preds", identity(struct("a"))).toPandas()
{code}
Actual output with NumPy 2.1.0 (predictions vary from row to row and differ between the two runs):
{code:java}
>>> df1
      a     b     c     d     preds
0   0.0   1.0   2.0   3.0  0.431752
1   4.0   5.0   6.0   7.0  0.912097
2   8.0   9.0  10.0  11.0  0.679628
3  12.0  13.0  14.0  15.0  0.853850
4  16.0  17.0  18.0  19.0  0.389971
5  20.0  21.0  22.0  23.0  0.654521
6  24.0  25.0  26.0  27.0  0.430569
7  28.0  29.0  30.0  31.0  0.331055
8  32.0  33.0  34.0  35.0  0.306073
>>> df2
      a     b     c     d     preds
0   0.0   1.0   2.0   3.0  0.679628
1   4.0   5.0   6.0   7.0  0.430569
2   8.0   9.0  10.0  11.0  0.853850
3  12.0  13.0  14.0  15.0  0.306073
4  16.0  17.0  18.0  19.0  0.654521
5  20.0  21.0  22.0  23.0  0.389971
6  24.0  25.0  26.0  27.0  0.507598
7  28.0  29.0  30.0  31.0  0.912097
8  32.0  33.0  34.0  35.0  0.431752 {code}
Expected output (the same constant prediction for every row in both DataFrames):
{code:java}
>>> df1
      a     b     c     d     preds
0   0.0   1.0   2.0   3.0  0.685941
1   4.0   5.0   6.0   7.0  0.685941
2   8.0   9.0  10.0  11.0  0.685941
3  12.0  13.0  14.0  15.0  0.685941
4  16.0  17.0  18.0  19.0  0.685941
5  20.0  21.0  22.0  23.0  0.685941
6  24.0  25.0  26.0  27.0  0.685941
7  28.0  29.0  30.0  31.0  0.685941
8  32.0  33.0  34.0  35.0  0.685941
>>> df2
      a     b     c     d     preds
0   0.0   1.0   2.0   3.0  0.685941
1   4.0   5.0   6.0   7.0  0.685941
2   8.0   9.0  10.0  11.0  0.685941
3  12.0  13.0  14.0  15.0  0.685941
4  16.0  17.0  18.0  19.0  0.685941
5  20.0  21.0  22.0  23.0  0.685941
6  24.0  25.0  26.0  27.0  0.685941
7  28.0  29.0  30.0  31.0  0.685941
8  32.0  33.0  34.0  35.0  0.685941 {code}
 

> Enable PredictBatchUDFTests.test_caching for NumPy 2
> ----------------------------------------------------
>
>                 Key: SPARK-49793
>                 URL: https://issues.apache.org/jira/browse/SPARK-49793
>             Project: Spark
>          Issue Type: Story
>          Components: ML, Tests
>    Affects Versions: 4.0.0
>            Reporter: Xinrong Meng
>            Priority: Major
>
>  
> {code:java}
> import numpy as np
> import pandas as pd
> from pyspark.ml.functions import predict_batch_udf
> from pyspark.sql.types import DoubleType
> from pyspark.sql.functions import struct
> data = np.arange(0, 36, dtype=np.float64).reshape(-1, 4)
> pdf = pd.DataFrame(data, columns=["a", "b", "c", "d"])
> df = spark.createDataFrame(pdf)
> def make_predict_fn():
>     fake_output = np.random.random()
>     def predict(inputs):
>         return np.array([fake_output for i in inputs])
>     return predict
>  
> identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), 
> batch_size=5)
> df1 = df.withColumn("preds", identity(struct("a"))).toPandas()
> df2 = df.withColumn("preds", identity(struct("a"))).toPandas()
> {code}
> Actual output with NumPy 2.1.0 (predictions vary from row to row and differ between the two runs):
> {code:java}
> >>> df1
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.431752
> 1   4.0   5.0   6.0   7.0  0.912097
> 2   8.0   9.0  10.0  11.0  0.679628
> 3  12.0  13.0  14.0  15.0  0.853850
> 4  16.0  17.0  18.0  19.0  0.389971
> 5  20.0  21.0  22.0  23.0  0.654521
> 6  24.0  25.0  26.0  27.0  0.430569
> 7  28.0  29.0  30.0  31.0  0.331055
> 8  32.0  33.0  34.0  35.0  0.306073
> >>> df2
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.679628
> 1   4.0   5.0   6.0   7.0  0.430569
> 2   8.0   9.0  10.0  11.0  0.853850
> 3  12.0  13.0  14.0  15.0  0.306073
> 4  16.0  17.0  18.0  19.0  0.654521
> 5  20.0  21.0  22.0  23.0  0.389971
> 6  24.0  25.0  26.0  27.0  0.507598
> 7  28.0  29.0  30.0  31.0  0.912097
> 8  32.0  33.0  34.0  35.0  0.431752 {code}
> Expected output (the same constant prediction for every row in both DataFrames):
> {code:java}
> >>> df1
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.685941
> 1   4.0   5.0   6.0   7.0  0.685941
> 2   8.0   9.0  10.0  11.0  0.685941
> 3  12.0  13.0  14.0  15.0  0.685941
> 4  16.0  17.0  18.0  19.0  0.685941
> 5  20.0  21.0  22.0  23.0  0.685941
> 6  24.0  25.0  26.0  27.0  0.685941
> 7  28.0  29.0  30.0  31.0  0.685941
> 8  32.0  33.0  34.0  35.0  0.685941
> >>> df2
>       a     b     c     d     preds
> 0   0.0   1.0   2.0   3.0  0.685941
> 1   4.0   5.0   6.0   7.0  0.685941
> 2   8.0   9.0  10.0  11.0  0.685941
> 3  12.0  13.0  14.0  15.0  0.685941
> 4  16.0  17.0  18.0  19.0  0.685941
> 5  20.0  21.0  22.0  23.0  0.685941
> 6  24.0  25.0  26.0  27.0  0.685941
> 7  28.0  29.0  30.0  31.0  0.685941
> 8  32.0  33.0  34.0  35.0  0.685941 {code}
>  



--
This message was sent by Atlassian Jira
(v8.20.10#820010)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org

Reply via email to