Repository: spark
Updated Branches:
  refs/heads/master d2b2932d8 -> 3e6a714c9


[SPARK-21766][PYSPARK][SQL] DataFrame toPandas() raises ValueError with nullable int columns

## What changes were proposed in this pull request?

When calling `DataFrame.toPandas()` (without Arrow enabled), if there is an
`IntegralType` column (`IntegerType`, `ShortType`, `ByteType`) that has null
values, the following exception is thrown:

    ValueError: Cannot convert non-finite values (NA or inf) to integer

This is because the null values are first converted to float NaN during the
construction of the Pandas DataFrame in `from_records`, and the column is then
cast back to an integer type, which fails on NaN.
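A minimal reproduction of the failure in plain pandas, without Spark, may help
here. It is a sketch that assumes a reasonably recent pandas and NumPy, and it
uses a hand-written list of tuples in place of `self.collect()`. Newer pandas
releases raise `IntCastingNaNError`, which is a `ValueError` subclass, so
catching `ValueError` covers both old and new versions.

```python
import numpy as np
import pandas as pd

# Rows standing in for self.collect(): the integer column "a" contains a null.
records = [(1, "foo"), (None, "bar")]
pdf = pd.DataFrame.from_records(records, columns=["a", "b"])

# During construction pandas stores the None as NaN, so "a" is a float column.
print(pdf["a"].dtype)

# Casting that column back to the Spark-derived integer type fails on the NaN.
error = None
try:
    pdf["a"].astype(np.int32)
except ValueError as exc:
    error = exc
print(type(error).__name__, error)
```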

The fix checks, for each nullable integral column, whether the Pandas DataFrame
actually contains null values; if so, the cast is skipped and the float type
inferred by Pandas is kept.
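The check can be sketched outside Spark as below. `to_pandas_sketch` is a
hypothetical helper, not a PySpark API: `corrected_dtypes` stands in for what
`_to_corrected_pandas_type` returns per field, and `nullable_integral` stands in
for the `isinstance(field.dataType, IntegralType) and field.nullable` test on
the schema. As in the patch, the DataFrame is built before the dtype map,
because the null check needs the collected values.

```python
import numpy as np
import pandas as pd


def to_pandas_sketch(records, columns, corrected_dtypes, nullable_integral):
    """Hypothetical stand-in for the non-Arrow branch of toPandas().

    records: list of tuples (assumed shape of DataFrame.collect()).
    columns: column names.
    corrected_dtypes: dict of column name -> NumPy dtype; missing means no cast.
    nullable_integral: set of column names that are nullable integral fields.
    """
    # Build the frame first so the inferred values can be inspected.
    pdf = pd.DataFrame.from_records(records, columns=columns)

    dtype = {}
    for name in columns:
        pandas_type = corrected_dtypes.get(name)
        # Skip the cast when a nullable integral column really holds nulls:
        # pandas already stored them as NaN in a float64 column.
        if pandas_type is not None and not (
                name in nullable_integral and pdf[name].isnull().any()):
            dtype[name] = pandas_type

    # Apply the remaining casts; these columns are known to be NaN-free.
    for f, t in dtype.items():
        pdf[f] = pdf[f].astype(t)
    return pdf


# Same data as the new test: "a" and "c" are nullable IntegerType columns.
records = [(1, "foo", 16777220), (None, "bar", None)]
out = to_pandas_sketch(records, ["a", "b", "c"],
                       {"a": np.int32, "c": np.int32}, {"a", "c"})
print(out.dtypes)
```

With the test's data, both integer columns stay `float64`; a batch with no
nulls in those columns is still cast to `int32`, so the schema-derived type is
only abandoned when a cast would actually fail.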

Closes #18945

## How was this patch tested?

Added a PySpark test (`test_to_pandas_avoid_astype`).

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #19319 from viirya/SPARK-21766.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3e6a714c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3e6a714c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3e6a714c

Branch: refs/heads/master
Commit: 3e6a714c9ee97ef13b3f2010babded3b63fd9d74
Parents: d2b2932
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Fri Sep 22 22:39:47 2017 +0900
Committer: hyukjinkwon <gurwls...@gmail.com>
Committed: Fri Sep 22 22:39:47 2017 +0900

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py | 13 ++++++++++---
 python/pyspark/sql/tests.py     | 12 ++++++++++++
 2 files changed, 22 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3e6a714c/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 88ac413..7b81a0b 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -37,6 +37,7 @@ from pyspark.sql.types import _parse_datatype_json_string
 from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
 from pyspark.sql.readwriter import DataFrameWriter
 from pyspark.sql.streaming import DataStreamWriter
+from pyspark.sql.types import IntegralType
 from pyspark.sql.types import *
 
 __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
@@ -1891,14 +1892,20 @@ class DataFrame(object):
                       "if using spark.sql.execution.arrow.enable=true"
                 raise ImportError("%s\n%s" % (e.message, msg))
         else:
+            pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
             dtype = {}
             for field in self.schema:
                 pandas_type = _to_corrected_pandas_type(field.dataType)
-                if pandas_type is not None:
+                # SPARK-21766: if an integer field is nullable and has null values, it can be
+                # inferred by pandas as float column. Once we convert the column with NaN back
+                # to integer type e.g., np.int16, we will hit exception. So we use the inferred
+                # float type, not the corrected type from the schema in this case.
+                if pandas_type is not None and \
+                    not(isinstance(field.dataType, IntegralType) and field.nullable and
+                        pdf[field.name].isnull().any()):
                     dtype[field.name] = pandas_type
 
-            pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
-
             for f, t in dtype.items():
                 pdf[f] = pdf[f].astype(t, copy=False)
             return pdf

http://git-wip-us.apache.org/repos/asf/spark/blob/3e6a714c/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ab76c48..3db8bee 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2564,6 +2564,18 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEquals(types[2], np.bool)
         self.assertEquals(types[3], np.float32)
 
+    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    def test_to_pandas_avoid_astype(self):
+        import numpy as np
+        schema = StructType().add("a", IntegerType()).add("b", StringType())\
+                             .add("c", IntegerType())
+        data = [(1, "foo", 16777220), (None, "bar", None)]
+        df = self.spark.createDataFrame(data, schema)
+        types = df.toPandas().dtypes
+        self.assertEquals(types[0], np.float64)  # doesn't convert to np.int32 due to NaN value.
+        self.assertEquals(types[1], np.object)
+        self.assertEquals(types[2], np.float64)
+
     def test_create_dataframe_from_array_of_long(self):
         import array
         data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))]

