[ https://issues.apache.org/jira/browse/SPARK-24760?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16537864#comment-16537864 ]

Mortada Mehyar commented on SPARK-24760:
----------------------------------------

Setting the "new" column to be nullable does indeed avoid the exception; however, 
the NaN value then gets turned into NULL, which I'd argue is still not the 
correct behavior:
{code:java}
In [6]: @pandas_udf("key: string, value: int, new: double", 
PandasUDFType.GROUPED_MAP)
   ...: def func(pdf):
   ...:     pdf['new'] = pdf['value'].where(pdf['value'] > 0, float('nan'))
   ...:     return pdf


In [7]: df.groupby('key').apply(func).show()
+---+-----+----+
|key|value| new|
+---+-----+----+
|  b|    3| 3.0|
|  b|   -2|null|
|  a|    1| 1.0|
|  a|    2| 2.0|
+---+-----+----+
{code}
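For what it's worth, here is a quick way to confirm that the values really come back 
as NULL (rather than NaN) on the Spark side. This is just a sanity-check sketch using 
the standard isnan and isNull helpers, not part of the original repro:
{code:java}
from pyspark.sql.functions import col, isnan

result = df.groupby('key').apply(func)

# If the NaN survived the Arrow round trip, is_nan would be true for the -2 row;
# with the behavior above, is_null is true instead.
result.select('key', 'value', 'new',
              isnan('new').alias('is_nan'),
              col('new').isNull().alias('is_null')).show()
{code}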

> Pandas UDF does not handle NaN correctly
> ----------------------------------------
>
>                 Key: SPARK-24760
>                 URL: https://issues.apache.org/jira/browse/SPARK-24760
>             Project: Spark
>          Issue Type: Bug
>          Components: PySpark
>    Affects Versions: 2.3.0, 2.3.1
>         Environment: Spark 2.3.1
> Pandas 0.23.1
>            Reporter: Mortada Mehyar
>            Priority: Minor
>
> I noticed that having `NaN` values when using the new Pandas UDF feature 
> triggers a JVM exception. Not sure if this is an issue with PySpark or 
> PyArrow. Here is a somewhat contrived example to showcase the problem.
> {code}
> In [1]: import pandas as pd
>    ...: from pyspark.sql.functions import lit, pandas_udf, PandasUDFType
> In [2]: d = [{'key': 'a', 'value': 1},
>              {'key': 'a', 'value': 2},
>              {'key': 'b', 'value': 3},
>              {'key': 'b', 'value': -2}]
>         df = spark.createDataFrame(d, "key: string, value: int")
>         df.show()
> +---+-----+
> |key|value|
> +---+-----+
> |  a|    1|
> |  a|    2|
> |  b|    3|
> |  b|   -2|
> +---+-----+
> In [3]: df_tmp = df.withColumn('new', lit(1.0))  # add a DoubleType column
>         df_tmp.printSchema()
> root
>  |-- key: string (nullable = true)
>  |-- value: integer (nullable = true)
>  |-- new: double (nullable = false)
> {code}
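> Note that the "new" column created via lit(1.0) comes out as nullable = false, and 
> df_tmp.schema (with that non-nullable field) is what the UDF below declares as its 
> return schema. For comparison, an explicitly nullable schema could be built by hand; 
> this is only an illustrative sketch, not part of the original repro:
> {code}
> from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
> # Same fields as df_tmp.schema, but with "new" marked nullable
> nullable_schema = StructType([
>     StructField('key', StringType(), True),
>     StructField('value', IntegerType(), True),
>     StructField('new', DoubleType(), True),
> ])
> {code}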
> The Pandas UDF simply creates a new column in which negative values are replaced 
> by a particular float, in this case INF, and it works fine:
> {code}
> In [4]: @pandas_udf(df_tmp.schema, PandasUDFType.GROUPED_MAP)
>    ...: def func(pdf):
>    ...:     pdf['new'] = pdf['value'].where(pdf['value'] > 0, float('inf'))
>    ...:     return pdf
> In [5]: df.groupby('key').apply(func).show()
> +---+-----+----------+
> |key|value|       new|
> +---+-----+----------+
> |  b|    3|       3.0|
> |  b|   -2|  Infinity|
> |  a|    1|       1.0|
> |  a|    2|       2.0|
> +---+-----+----------+
> {code}
> However, if we set this value to NaN, it triggers an exception:
> {code}
> In [6]: @pandas_udf(df_tmp.schema, PandasUDFType.GROUPED_MAP)
>    ...: def func(pdf):
>    ...:     pdf['new'] = pdf['value'].where(pdf['value'] > 0, float('nan'))
>    ...:     return pdf
>    ...:
>    ...: df.groupby('key').apply(func).show()
> [Stage 23:======================================================> (73 + 2) / 75]
> 2018-07-07 16:26:27 ERROR Executor:91 - Exception in task 36.0 in stage 23.0 (TID 414)
> java.lang.IllegalStateException: Value at index is null
>       at org.apache.arrow.vector.Float8Vector.get(Float8Vector.java:98)
>       at org.apache.spark.sql.vectorized.ArrowColumnVector$DoubleAccessor.getDouble(ArrowColumnVector.java:344)
>       at org.apache.spark.sql.vectorized.ArrowColumnVector.getDouble(ArrowColumnVector.java:99)
>       at org.apache.spark.sql.execution.vectorized.MutableColumnarRow.getDouble(MutableColumnarRow.java:126)
>       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
>       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
>       at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
>       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
>       at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
>       at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
>       at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
>       at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
>       at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
>       at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
>       at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
>       at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
>       at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
>       at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
>       at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
>       at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
>       at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
>       at org.apache.spark.scheduler.Task.run(Task.scala:109)
>       at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
>       at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
>       at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
>       at java.lang.Thread.run(Thread.java:745)
> 2018-07-07 16:26:27 WARN  TaskSetManager:66 - Lost task 36.0 in stage 23.0 (TID 414, localhost, executor driver): java.lang.IllegalStateException: Value at index is null
>       at org.apache.arrow.vector.Float8Vector.get(Float8Vector.java:98)
>       at org.apache.spark.sql.vectorized.ArrowColumnVector$DoubleAccessor.getDouble(ArrowColumnVector.java:344)
>       at org.apache.spark.sql.vectorized.ArrowColumnVector.getDouble(ArrowColumnVector.java:99)
>       at org.apache.spark.sql.execution.vectorized.MutableColumnarRow.getDouble(MutableColumnarRow.java:126)
>       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
>       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
>       at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
>       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
>       at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
>       at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
>       at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
>       at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
>       at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
>       at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
>       at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
>       at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
>       at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
>       at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
>       at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
>       at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
>       at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
>       at org.apache.spark.scheduler.Task.run(Task.scala:109)
>       at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
>       at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
>       at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
>       at java.lang.Thread.run(Thread.java:745)
> {code}
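> To help figure out whether the NaN-to-NULL conversion happens on the PySpark side or 
> inside PyArrow, one quick standalone check is to convert a pandas Series through Arrow 
> directly. This is only an illustrative sketch (it assumes pyarrow is importable in the 
> same environment):
> {code}
> import pandas as pd
> import pyarrow as pa
> # Converting a pandas float Series to an Arrow array treats NaN as null by default,
> # which may be where the NULL in the Spark result comes from.
> arr = pa.Array.from_pandas(pd.Series([1.0, float('nan')]))
> print(arr)             # expect the second element to be shown as null
> print(arr.null_count)  # expect 1
> {code}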


