Repository: spark
Updated Branches:
  refs/heads/branch-2.3 521494d7b -> 44933033e


[SPARK-23290][SQL][PYTHON][BACKPORT-2.3] Use datetime.date for date type when converting Spark DataFrame to Pandas DataFrame.

## What changes were proposed in this pull request?

This is a backport of #20506.

In #18664, there was a change in how `DateType` is returned to users ([line 1968 in dataframe.py](https://github.com/apache/spark/pull/18664/files#diff-6fc344560230bf0ef711bb9b5573f1faR1968)), which can cause client code that works in Spark 2.2 to fail. See [SPARK-23290](https://issues.apache.org/jira/browse/SPARK-23290?focusedCommentId=16350917&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16350917) for an example.

This PR changes the conversion to use `datetime.date` for date type values, as Spark 2.2 does.
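
For illustration, a minimal sketch of the restored behavior (assuming an active `SparkSession` named `spark`; this snippet is not part of the patch):

```python
import datetime
from pyspark.sql.types import StructType, StructField, DateType

df = spark.createDataFrame([(datetime.date(2012, 2, 2),)],
                           StructType([StructField("d", DateType(), True)]))
pdf = df.toPandas()

# With this change the column has dtype `object` and holds datetime.date
# values, as in Spark 2.2, instead of pandas datetime64[ns] timestamps.
assert isinstance(pdf["d"][0], datetime.date)
```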

## How was this patch tested?

Tests modified to fit the new behavior, plus existing tests.

Author: Takuya UESHIN <ues...@databricks.com>

Closes #20515 from ueshin/issues/SPARK-23290_2.3.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/44933033
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/44933033
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/44933033

Branch: refs/heads/branch-2.3
Commit: 44933033e9216ccb2e533b9dc6e6cb03cce39817
Parents: 521494d
Author: Takuya UESHIN <ues...@databricks.com>
Authored: Tue Feb 6 18:29:37 2018 +0900
Committer: hyukjinkwon <gurwls...@gmail.com>
Committed: Tue Feb 6 18:29:37 2018 +0900

----------------------------------------------------------------------
 python/pyspark/serializers.py   |  9 ++++--
 python/pyspark/sql/dataframe.py |  7 ++---
 python/pyspark/sql/tests.py     | 57 ++++++++++++++++++++++++++----------
 python/pyspark/sql/types.py     | 15 ++++++++++
 4 files changed, 66 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/44933033/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 88d6a19..e870325 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -267,12 +267,15 @@ class ArrowStreamPandasSerializer(Serializer):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
-        from pyspark.sql.types import _check_dataframe_localize_timestamps
+        from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
+            _check_dataframe_localize_timestamps
         import pyarrow as pa
         reader = pa.open_stream(stream)
+        schema = from_arrow_schema(reader.schema)
         for batch in reader:
-            # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
-            pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone)
+            pdf = batch.to_pandas()
+            pdf = _check_dataframe_convert_date(pdf, schema)
+            pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
             yield [c for _, c in pdf.iteritems()]
 
     def __repr__(self):

http://git-wip-us.apache.org/repos/asf/spark/blob/44933033/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2e55407..59a4170 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1923,7 +1923,8 @@ class DataFrame(object):
 
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
             try:
-                from pyspark.sql.types import _check_dataframe_localize_timestamps
+                from pyspark.sql.types import _check_dataframe_convert_date, \
+                    _check_dataframe_localize_timestamps
                 from pyspark.sql.utils import require_minimum_pyarrow_version
                 import pyarrow
                 require_minimum_pyarrow_version()
@@ -1931,6 +1932,7 @@ class DataFrame(object):
                 if tables:
                     table = pyarrow.concat_tables(tables)
                     pdf = table.to_pandas()
+                    pdf = _check_dataframe_convert_date(pdf, self.schema)
                     return _check_dataframe_localize_timestamps(pdf, timezone)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
@@ -2009,7 +2011,6 @@ def _to_corrected_pandas_type(dt):
     """
     When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
     This method gets the corrected data type for Pandas if that type may be inferred incorrectly.
-    NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns]
     """
     import numpy as np
     if type(dt) == ByteType:
@@ -2020,8 +2021,6 @@ def _to_corrected_pandas_type(dt):
         return np.int32
     elif type(dt) == FloatType:
         return np.float32
-    elif type(dt) == DateType:
-        return 'datetime64[ns]'
     else:
         return None
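
With `DateType` removed from the mapping above, `_to_corrected_pandas_type` returns `None` for dates, so the non-Arrow path leaves the column as `object` dtype holding `datetime.date` values. A minimal sketch of the resulting mapping (assuming a PySpark 2.3 install; `_to_corrected_pandas_type` is a private module-level helper, so this is illustration only):

```python
import numpy as np
from pyspark.sql.dataframe import _to_corrected_pandas_type
from pyspark.sql.types import FloatType, DateType

# Numeric types still map to an explicit numpy dtype...
assert _to_corrected_pandas_type(FloatType()) == np.float32
# ...but DateType now maps to None, so pandas keeps the column as `object`
# and the values stay datetime.date instead of being cast to datetime64[ns].
assert _to_corrected_pandas_type(DateType()) is None
```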
 

http://git-wip-us.apache.org/repos/asf/spark/blob/44933033/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0d5bc13..95b9c0e 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2810,7 +2810,7 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEquals(types[1], np.object)
         self.assertEquals(types[2], np.bool)
         self.assertEquals(types[3], np.float32)
-        self.assertEquals(types[4], 'datetime64[ns]')
+        self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
     @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
@@ -3356,7 +3356,7 @@ class ArrowTests(ReusedSQLTestCase):
 
     @classmethod
     def setUpClass(cls):
-        from datetime import datetime
+        from datetime import date, datetime
         from decimal import Decimal
         ReusedSQLTestCase.setUpClass()
 
@@ -3378,11 +3378,11 @@ class ArrowTests(ReusedSQLTestCase):
             StructField("7_date_t", DateType(), True),
             StructField("8_timestamp_t", TimestampType(), True)])
         cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
-                     datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                     date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
                     (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
-                     datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                     date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
                     (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
-                     datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+                     date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
 
     @classmethod
     def tearDownClass(cls):
@@ -3435,7 +3435,9 @@ class ArrowTests(ReusedSQLTestCase):
     def test_toPandas_arrow_toggle(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
-        self.assertFramesEqual(pdf_arrow, pdf)
+        expected = self.create_pandas_data_frame()
+        self.assertFramesEqual(expected, pdf)
+        self.assertFramesEqual(expected, pdf_arrow)
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
@@ -4036,18 +4038,42 @@ class VectorizedUDFTests(ReusedSQLTestCase):
             with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
                 df.select(f(col('map'))).collect()
 
-    def test_vectorized_udf_null_date(self):
+    def test_vectorized_udf_dates(self):
         from pyspark.sql.functions import pandas_udf, col
         from datetime import date
-        schema = StructType().add("date", DateType())
-        data = [(date(1969, 1, 1),),
-                (date(2012, 2, 2),),
-                (None,),
-                (date(2100, 4, 4),)]
+        schema = StructType().add("idx", LongType()).add("date", DateType())
+        data = [(0, date(1969, 1, 1),),
+                (1, date(2012, 2, 2),),
+                (2, None,),
+                (3, date(2100, 4, 4),)]
         df = self.spark.createDataFrame(data, schema=schema)
-        date_f = pandas_udf(lambda t: t, returnType=DateType())
-        res = df.select(date_f(col("date")))
-        self.assertEquals(df.collect(), res.collect())
+
+        date_copy = pandas_udf(lambda t: t, returnType=DateType())
+        df = df.withColumn("date_copy", date_copy(col("date")))
+
+        @pandas_udf(returnType=StringType())
+        def check_data(idx, date, date_copy):
+            import pandas as pd
+            msgs = []
+            is_equal = date.isnull()
+            for i in range(len(idx)):
+                if (is_equal[i] and data[idx[i]][1] is None) or \
+                        date[i] == data[idx[i]][1]:
+                    msgs.append(None)
+                else:
+                    msgs.append(
+                        "date values are not equal (date='%s': 
data[%d][1]='%s')"
+                        % (date[i], idx[i], data[idx[i]][1]))
+            return pd.Series(msgs)
+
+        result = df.withColumn("check_data",
+                               check_data(col("idx"), col("date"), col("date_copy"))).collect()
+
+        self.assertEquals(len(data), len(result))
+        for i in range(len(result)):
+            self.assertEquals(data[i][1], result[i][1])  # "date" col
+            self.assertEquals(data[i][1], result[i][2])  # "date_copy" col
+            self.assertIsNone(result[i][3])  # "check_data" col
 
     def test_vectorized_udf_timestamps(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4088,6 +4114,7 @@ class VectorizedUDFTests(ReusedSQLTestCase):
         self.assertEquals(len(data), len(result))
         for i in range(len(result)):
             self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
+            self.assertEquals(data[i][1], result[i][2])  # "timestamp_copy" col
             self.assertIsNone(result[i][3])  # "check_data" col
 
     def test_vectorized_udf_return_timestamp_tz(self):

http://git-wip-us.apache.org/repos/asf/spark/blob/44933033/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 0dc5823..093dae5 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1694,6 +1694,21 @@ def from_arrow_schema(arrow_schema):
          for field in arrow_schema])
 
 
+def _check_dataframe_convert_date(pdf, schema):
+    """ Correct date type value to use datetime.date.
+
+    Pandas DataFrame created from PyArrow uses datetime64[ns] for date type 
values, but we should
+    use datetime.date to match the behavior with when Arrow optimization is 
disabled.
+
+    :param pdf: pandas.DataFrame
+    :param schema: a Spark schema of the pandas.DataFrame
+    """
+    for field in schema:
+        if type(field.dataType) == DateType:
+            pdf[field.name] = pdf[field.name].dt.date
+    return pdf
+
+
 def _check_dataframe_localize_timestamps(pdf, timezone):
     """
     Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
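
A hedged usage sketch of the new helper (assuming PySpark 2.3, where the private `_check_dataframe_convert_date` is defined at module level in `pyspark/sql/types.py`; not part of the patch):

```python
import pandas as pd
from pyspark.sql.types import (StructType, StructField, DateType,
                               _check_dataframe_convert_date)

# A datetime64[ns] column, as Arrow's to_pandas() would produce for dates.
pdf = pd.DataFrame({"d": pd.to_datetime(["2012-02-02", "2100-03-03"])})
schema = StructType([StructField("d", DateType(), True)])

pdf = _check_dataframe_convert_date(pdf, schema)
# Series.dt.date turns each value into a datetime.date, matching the
# behavior of the non-Arrow conversion path.
print(type(pdf["d"][0]))  # <class 'datetime.date'>
```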


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
