This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c5bcac  [SPARK-36626][PYTHON] Support TimestampNTZ in createDataFrame/toPandas and Python UDFs
9c5bcac is described below

commit 9c5bcac61ee56fbb271e890cc33f9a983612c5b0
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Sep 2 14:00:27 2021 +0900

    [SPARK-36626][PYTHON] Support TimestampNTZ in createDataFrame/toPandas and Python UDFs
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to implement `TimestampNTZType` support in PySpark's `SparkSession.createDataFrame`, `DataFrame.toPandas`, Python UDFs, and pandas UDFs, both with and without Arrow.
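
    For illustration, a minimal sketch of the Python UDF path (assuming an active `spark` session; the identity `noop` UDF is a hypothetical example adapted from the tests added in this PR):

    ```python
    import datetime

    from pyspark.sql.functions import udf
    from pyspark.sql.types import TimestampNTZType

    # Identity UDF; the TimestampNTZType return type is the newly supported part.
    @udf(TimestampNTZType())
    def noop(x):
        return x

    df = (spark
          .createDataFrame([(datetime.datetime(1970, 1, 1, 0, 0),)], schema="dt timestamp_ntz")
          .select(noop("dt").alias("dt")))

    # The value round-trips without any session-timezone adjustment.
    assert df.schema[0].dataType.typeName() == "timestamp_ntz"
    assert df.first()[0] == datetime.datetime(1970, 1, 1, 0, 0)
    ```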
    
    ### Why are the changes needed?
    
    To complete `TimestampNTZType` support.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    - Users can now use the `TimestampNTZType` type in `SparkSession.createDataFrame`, `DataFrame.toPandas`, Python UDFs, and pandas UDFs, both with and without Arrow.
    
    - If `spark.sql.timestampType` is configured to `TIMESTAMP_NTZ`, `SparkSession.createDataFrame` infers a `datetime` without a timezone as `TimestampNTZType`; a `datetime` with a timezone is still inferred as `TimestampType` (see the sketch after this list).
    
        - If `TimestampType` and `TimestampNTZType` conflict while merging the inferred schema, `TimestampType` takes precedence.
    
    - If the type is `TimestampNTZType`, the value is treated internally as having an unknown timezone, computed in UTC (the same as on the JVM side), and not localized externally.
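
    A minimal sketch of the inference behaviour (again assuming an active `spark` session; adapted from the tests added in this PR):

    ```python
    import datetime

    spark.conf.set("spark.sql.timestampType", "TIMESTAMP_NTZ")

    # A naive datetime is now inferred as TIMESTAMP_NTZ.
    df = spark.createDataFrame([(datetime.datetime(1970, 1, 1, 0, 0),)])
    assert df.schema[0].dataType.simpleString() == "timestamp_ntz"

    # Mixing naive and timezone-aware values falls back to TIMESTAMP,
    # because TimestampType takes precedence when the inferred types conflict.
    df = spark.createDataFrame([
        (datetime.datetime(1970, 1, 1, 0, 0),),
        (datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc),),
    ])
    assert df.schema[0].dataType.simpleString() == "timestamp"
    ```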
    
    ### How was this patch tested?
    
    Manually tested, and unit tests were added.
    
    Closes #33876 from HyukjinKwon/SPARK-36626.
    
    Lead-authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Co-authored-by: Dominik Gehl <d...@open.ch>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/docs/source/reference/pyspark.sql.rst       |  1 +
 python/pyspark/sql/pandas/conversion.py            | 17 ++++--
 python/pyspark/sql/pandas/conversion.pyi           |  1 +
 python/pyspark/sql/pandas/serializers.py           |  4 +-
 python/pyspark/sql/pandas/types.py                 |  8 ++-
 python/pyspark/sql/session.py                      | 20 +++++--
 python/pyspark/sql/tests/test_arrow.py             | 16 +++++-
 python/pyspark/sql/tests/test_dataframe.py         | 34 ++++++++----
 python/pyspark/sql/tests/test_pandas_udf.py        | 18 +++++++
 python/pyspark/sql/tests/test_types.py             | 16 +++++-
 python/pyspark/sql/tests/test_udf.py               | 20 ++++++-
 python/pyspark/sql/types.py                        | 62 +++++++++++++++++-----
 python/pyspark/sql/types.pyi                       |  5 ++
 .../sql/execution/python/EvaluatePython.scala      |  4 +-
 14 files changed, 185 insertions(+), 41 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst
index d72332a..605a150 100644
--- a/python/docs/source/reference/pyspark.sql.rst
+++ b/python/docs/source/reference/pyspark.sql.rst
@@ -298,6 +298,7 @@ Data Types
     StringType
     StructField
     StructType
+    TimestampNTZType
     TimestampType
 
 
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 92ef7ce..5454410 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -22,7 +22,7 @@ from pyspark.rdd import _load_from_socket
 from pyspark.sql.pandas.serializers import ArrowCollectSerializer
 from pyspark.sql.types import IntegralType
 from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
-    DoubleType, BooleanType, MapType, TimestampType, StructType, DataType
+    DoubleType, BooleanType, MapType, TimestampType, TimestampNTZType, StructType, DataType
 from pyspark.traceback_utils import SCCallSiteSync
 
 
@@ -238,6 +238,8 @@ class PandasConversionMixin(object):
             return np.bool
         elif type(dt) == TimestampType:
             return np.datetime64
+        elif type(dt) == TimestampNTZType:
+            return np.datetime64
         else:
             return None
 
@@ -354,6 +356,8 @@ class SparkConversionMixin(object):
 
         if timezone is not None:
             from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
+            from pandas.core.dtypes.common import is_datetime64tz_dtype
+
             copied = False
             if isinstance(schema, StructType):
                 for field in schema:
@@ -368,8 +372,11 @@ class SparkConversionMixin(object):
                                 copied = True
                             pdf[field.name] = s
             else:
+                should_localize = not self._is_timestamp_ntz_preferred()
                 for column, series in pdf.iteritems():
-                    s = _check_series_convert_timestamps_tz_local(series, timezone)
+                    s = series
+                    if should_localize and is_datetime64tz_dtype(s.dtype) and s.dt.tz is not None:
+                        s = _check_series_convert_timestamps_tz_local(series, timezone)
                     if s is not series:
                         if not copied:
                             # Copy once if the series is modified to prevent the original
@@ -448,8 +455,12 @@ class SparkConversionMixin(object):
         if isinstance(schema, (list, tuple)):
             arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
             struct = StructType()
+            prefer_timestamp_ntz = self._is_timestamp_ntz_preferred()
             for name, field in zip(schema, arrow_schema):
-                struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
+                struct.add(
+                    name,
+                    from_arrow_type(field.type, prefer_timestamp_ntz),
+                    nullable=field.nullable)
             schema = struct
 
         # Determine arrow types to coerce data when creating batches
diff --git a/python/pyspark/sql/pandas/conversion.pyi b/python/pyspark/sql/pandas/conversion.pyi
index 031852f..87637722 100644
--- a/python/pyspark/sql/pandas/conversion.pyi
+++ b/python/pyspark/sql/pandas/conversion.pyi
@@ -38,6 +38,7 @@ from pyspark.sql.types import (  # noqa: F401
     ShortType as ShortType,
     StructType as StructType,
     TimestampType as TimestampType,
+    TimestampNTZType as TimestampNTZType,
 )
 from pyspark.traceback_utils import SCCallSiteSync as SCCallSiteSync  # noqa: F401
 
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 2dcfdc1..268b766 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -126,7 +126,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         # datetime64[ns] type handling.
         s = arrow_column.to_pandas(date_as_object=True)
 
-        if pyarrow.types.is_timestamp(arrow_column.type):
+        if pyarrow.types.is_timestamp(arrow_column.type) and arrow_column.type.tz is not None:
             return _check_series_localize_timestamps(s, self._timezone)
         elif pyarrow.types.is_map(arrow_column.type):
             return _convert_map_items_to_dict(s)
@@ -162,7 +162,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         def create_array(s, t):
             mask = s.isnull()
             # Ensure timestamp series are in expected form for Spark internal representation
-            if t is not None and pa.types.is_timestamp(t):
+            if t is not None and pa.types.is_timestamp(t) and t.tz is not None:
                 s = _check_series_convert_timestamps_internal(s, self._timezone)
             elif t is not None and pa.types.is_map(t):
                 s = _convert_dict_to_map_items(s)
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 489b466..ceb71a3 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -22,7 +22,7 @@ pandas instances during the type conversion.
 
 from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType, LongType, \
     FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, \
-    ArrayType, MapType, StructType, StructField, NullType
+    TimestampNTZType, ArrayType, MapType, StructType, StructField, NullType
 
 
 def to_arrow_type(dt):
@@ -55,6 +55,8 @@ def to_arrow_type(dt):
     elif type(dt) == TimestampType:
         # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
         arrow_type = pa.timestamp('us', tz='UTC')
+    elif type(dt) == TimestampNTZType:
+        arrow_type = pa.timestamp('us', tz=None)
     elif type(dt) == ArrayType:
         if type(dt.elementType) in [StructType, TimestampType]:
             raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
@@ -88,7 +90,7 @@ def to_arrow_schema(schema):
     return pa.schema(fields)
 
 
-def from_arrow_type(at):
+def from_arrow_type(at, prefer_timestamp_ntz=False):
     """ Convert pyarrow type to Spark data type.
     """
     from distutils.version import LooseVersion
@@ -116,6 +118,8 @@ def from_arrow_type(at):
         spark_type = BinaryType()
     elif types.is_date32(at):
         spark_type = DateType()
+    elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
+        spark_type = TimestampNTZType()
     elif types.is_timestamp(at):
         spark_type = TimestampType()
     elif types.is_list(at):
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 28ae480..12ec649 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -419,6 +419,9 @@ class SparkSession(SparkConversionMixin):
 
         return DataFrame(jdf, self._wrapped)
 
+    def _is_timestamp_ntz_preferred(self):
+        return self._wrapped._conf.timestampType().typeName() == "timestamp_ntz"
+
     def _inferSchemaFromList(self, data, names=None):
         """
         Infer schema from list of Row, dict, or tuple.
@@ -437,7 +440,9 @@ class SparkSession(SparkConversionMixin):
         if not data:
             raise ValueError("can not infer schema from empty dataset")
         infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
-        schema = reduce(_merge_type, (_infer_schema(row, names, infer_dict_as_struct)
+        prefer_timestamp_ntz = self._is_timestamp_ntz_preferred()
+        schema = reduce(_merge_type, (
+            _infer_schema(row, names, infer_dict_as_struct, prefer_timestamp_ntz)
                         for row in data))
         if _has_nulltype(schema):
             raise ValueError("Some of types cannot be determined after inferring")
@@ -465,12 +470,18 @@ class SparkSession(SparkConversionMixin):
                              "can not infer schema")
 
         infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
+        prefer_timestamp_ntz = self._is_timestamp_ntz_preferred()
         if samplingRatio is None:
-            schema = _infer_schema(first, names=names, infer_dict_as_struct=infer_dict_as_struct)
+            schema = _infer_schema(
+                first,
+                names=names,
+                infer_dict_as_struct=infer_dict_as_struct,
+                prefer_timestamp_ntz=prefer_timestamp_ntz)
             if _has_nulltype(schema):
                 for row in rdd.take(100)[1:]:
                     schema = _merge_type(schema, _infer_schema(
-                        row, names=names, infer_dict_as_struct=infer_dict_as_struct))
+                        row, names=names, infer_dict_as_struct=infer_dict_as_struct,
+                        prefer_timestamp_ntz=prefer_timestamp_ntz))
                     if not _has_nulltype(schema):
                         break
                 else:
@@ -480,7 +491,8 @@ class SparkSession(SparkConversionMixin):
             if samplingRatio < 0.99:
                 rdd = rdd.sample(False, float(samplingRatio))
             schema = rdd.map(lambda row: _infer_schema(
-                row, names, infer_dict_as_struct=infer_dict_as_struct)).reduce(_merge_type)
+                row, names, infer_dict_as_struct=infer_dict_as_struct,
+                prefer_timestamp_ntz=prefer_timestamp_ntz)).reduce(_merge_type)
         return schema
 
     def _createFromRDD(self, rdd, schema, samplingRatio):
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index cca9ec4..e7fb590 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -27,8 +27,8 @@ from pyspark import SparkContext, SparkConf
 from pyspark.sql import Row, SparkSession
 from pyspark.sql.functions import rand, udf
 from pyspark.sql.types import StructType, StringType, IntegerType, LongType, \
-    FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, \
-    ArrayType, NullType
+    FloatType, DoubleType, DecimalType, DateType, TimestampType, TimestampNTZType, \
+    BinaryType, StructField, ArrayType, NullType
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
@@ -167,6 +167,18 @@ class ArrowTests(ReusedSQLTestCase):
         assert_frame_equal(expected, pdf)
         assert_frame_equal(expected, pdf_arrow)
 
+    def test_create_data_frame_to_pandas_timestamp_ntz(self):
+        # SPARK-36626: Test TimestampNTZ in createDataFrame and toPandas
+        with self.sql_conf({"spark.sql.session.timeZone": "America/Los_Angeles"}):
+            origin = pd.DataFrame({"a": [datetime.datetime(2012, 2, 2, 2, 2, 2)]})
+            df = self.spark.createDataFrame(
+                origin, schema=StructType([StructField("a", TimestampNTZType(), True)]))
+            df.selectExpr("assert_true('2012-02-02 02:02:02' == CAST(a AS STRING))").collect()
+
+            pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+            assert_frame_equal(origin, pdf)
+            assert_frame_equal(pdf, pdf_arrow)
+
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
 
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 4c38b27..32a0e65 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -25,7 +25,7 @@ import unittest
 from pyspark.sql import SparkSession, Row
 from pyspark.sql.functions import col, lit, count, sum, mean
 from pyspark.sql.types import StringType, IntegerType, DoubleType, StructType, StructField, \
-    BooleanType, DateType, TimestampType, FloatType
+    BooleanType, DateType, TimestampType, TimestampNTZType, FloatType
 from pyspark.sql.utils import AnalysisException, IllegalArgumentException
 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils, have_pyarrow, have_pandas, \
     pandas_requirement_message, pyarrow_requirement_message
@@ -575,12 +575,16 @@ class DataFrameTests(ReusedSQLTestCase):
         from datetime import datetime, date
         schema = StructType().add("a", IntegerType()).add("b", StringType())\
                              .add("c", BooleanType()).add("d", FloatType())\
-                             .add("dt", DateType()).add("ts", TimestampType())
+                             .add("dt", DateType()).add("ts", TimestampType())\
+                             .add("ts_ntz", TimestampNTZType())
         data = [
-            (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
-            (2, "foo", True, 5.0, None, None),
-            (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
-            (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
+            (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1),
+             datetime(1969, 1, 1, 1, 1, 1)),
+            (2, "foo", True, 5.0, None, None, None),
+            (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3),
+             datetime(2012, 3, 3, 3, 3, 3)),
+            (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4),
+             datetime(2100, 4, 4, 4, 4, 4)),
         ]
         df = self.spark.createDataFrame(data, schema)
         return df.toPandas()
@@ -596,6 +600,7 @@ class DataFrameTests(ReusedSQLTestCase):
         self.assertEqual(types[3], np.float32)
         self.assertEqual(types[4], np.object)  # datetime.date
         self.assertEqual(types[5], 'datetime64[ns]')
+        self.assertEqual(types[6], 'datetime64[ns]')
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     def test_to_pandas_with_duplicated_column_names(self):
@@ -662,7 +667,8 @@ class DataFrameTests(ReusedSQLTestCase):
             CAST(0 AS DOUBLE) AS double,
             CAST(1 AS BOOLEAN) AS boolean,
             CAST('foo' AS STRING) AS string,
-            CAST('2019-01-01' AS TIMESTAMP) AS timestamp
+            CAST('2019-01-01' AS TIMESTAMP) AS timestamp,
+            CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz
             """
             dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
             dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
@@ -682,7 +688,8 @@ class DataFrameTests(ReusedSQLTestCase):
             CAST(NULL AS DOUBLE) AS double,
             CAST(NULL AS BOOLEAN) AS boolean,
             CAST(NULL AS STRING) AS string,
-            CAST(NULL AS TIMESTAMP) AS timestamp
+            CAST(NULL AS TIMESTAMP) AS timestamp,
+            CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz
             """
             pdf = self.spark.sql(sql).toPandas()
             types = pdf.dtypes
@@ -695,6 +702,7 @@ class DataFrameTests(ReusedSQLTestCase):
             self.assertEqual(types[6], np.object)
             self.assertEqual(types[7], np.object)
             self.assertTrue(np.can_cast(np.datetime64, types[8]))
+            self.assertTrue(np.can_cast(np.datetime64, types[9]))
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     def test_to_pandas_from_mixed_dataframe(self):
@@ -710,9 +718,10 @@ class DataFrameTests(ReusedSQLTestCase):
             CAST(col6 AS DOUBLE) AS double,
             CAST(col7 AS BOOLEAN) AS boolean,
             CAST(col8 AS STRING) AS string,
-            timestamp_seconds(col9) AS timestamp
-            FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1),
-                        (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
+            timestamp_seconds(col9) AS timestamp,
+            timestamp_seconds(col10) AS timestamp_ntz
+            FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
+                        (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
             """
             pdf_with_some_nulls = self.spark.sql(sql).toPandas()
             pdf_with_only_nulls = self.spark.sql(sql).filter('tinyint is null').toPandas()
@@ -738,6 +747,9 @@ class DataFrameTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
         self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
         self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
+        df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp_ntz")
+        self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampNTZType))
+        self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
 
     @unittest.skipIf(have_pandas, "Required Pandas was found.")
     def test_create_dataframe_required_pandas_not_found(self):
diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py
index 975eb468..9ebc943 100644
--- a/python/pyspark/sql/tests/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/test_pandas_udf.py
@@ -16,6 +16,7 @@
 #
 
 import unittest
+import datetime
 
 from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
 from pyspark.sql.types import DoubleType, StructType, StructField, LongType
@@ -239,6 +240,23 @@ class PandasUDFTests(ReusedSQLTestCase):
         with 
self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
             df.withColumn('udf', udf('id')).collect()
 
+    def test_pandas_udf_timestamp_ntz(self):
+        # SPARK-36626: Test TimestampNTZ in pandas UDF
+        @pandas_udf(returnType="timestamp_ntz")
+        def noop(s):
+            assert s.iloc[0] == datetime.datetime(1970, 1, 1, 0, 0)
+            return s
+
+        with self.sql_conf({"spark.sql.session.timeZone": "Asia/Hong_Kong"}):
+            df = (self.spark
+                  .createDataFrame(
+                      [(datetime.datetime(1970, 1, 1, 0, 0),)], schema="dt timestamp_ntz")
+                  .select(noop("dt").alias("dt")))
+
+            df.selectExpr("assert_true('1970-01-01 00:00:00' == CAST(dt AS STRING))").collect()
+            self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz")
+            self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0))
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 8bdc837..1dbddf7 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -29,8 +29,8 @@ from pyspark.sql.functions import col
 from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.utils import AnalysisException
 from pyspark.sql.types import ByteType, ShortType, IntegerType, FloatType, DateType, \
-    TimestampType, MapType, StringType, StructType, StructField, ArrayType, DoubleType, LongType, \
-    DecimalType, BinaryType, BooleanType, NullType
+    TimestampType, MapType, StringType, StructType, StructField,\
+    ArrayType, DoubleType, LongType, DecimalType, BinaryType, BooleanType, NullType
 from pyspark.sql.types import (  # type: ignore
     _array_signed_int_typecode_ctype_mappings, _array_type_mappings,
     _array_unsigned_int_typecode_ctype_mappings, _infer_type, _make_type_verifier, _merge_type
@@ -175,6 +175,18 @@ class TypesTests(ReusedSQLTestCase):
         ]
         self.assertEqual(actual, expected)
 
+        with self.sql_conf({"spark.sql.timestampType": "TIMESTAMP_NTZ"}):
+            with self.sql_conf({"spark.sql.session.timeZone": "America/Sao_Paulo"}):
+                df = self.spark.createDataFrame([(datetime.datetime(1970, 1, 1, 0, 0),)])
+                self.assertEqual(list(df.schema)[0].dataType.simpleString(), "timestamp_ntz")
+                self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0))
+
+            df = self.spark.createDataFrame([
+                (datetime.datetime(1970, 1, 1, 0, 0),),
+                (datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc),)
+            ])
+            self.assertEqual(list(df.schema)[0].dataType.simpleString(), "timestamp")
+
     def test_infer_schema_not_enough_names(self):
         df = self.spark.createDataFrame([["a", "b"]], ["col1"])
         self.assertEqual(df.columns, ['col1', '_2'])
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 0d13361..98d193f 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -20,13 +20,14 @@ import pydoc
 import shutil
 import tempfile
 import unittest
+import datetime
 
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, Column, Row
 from pyspark.sql.functions import udf
 from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.types import StringType, IntegerType, BooleanType, DoubleType, LongType, \
-    ArrayType, StructType, StructField
+    ArrayType, StructType, StructField, TimestampNTZType
 from pyspark.sql.utils import AnalysisException
 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
 from pyspark.testing.utils import QuietTest
@@ -552,6 +553,23 @@ class UDFTests(ReusedSQLTestCase):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
+    def test_udf_timestamp_ntz(self):
+        # SPARK-36626: Test TimestampNTZ in Python UDF
+        @udf(TimestampNTZType())
+        def noop(x):
+            assert x == datetime.datetime(1970, 1, 1, 0, 0)
+            return x
+
+        with self.sql_conf({"spark.sql.session.timeZone": "Pacific/Honolulu"}):
+            df = (self.spark
+                  .createDataFrame(
+                      [(datetime.datetime(1970, 1, 1, 0, 0),)], schema="dt timestamp_ntz")
+                  .select(noop("dt").alias("dt")))
+
+            df.selectExpr("assert_true('1970-01-01 00:00:00' == CAST(dt AS STRING))").collect()
+            self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz")
+            self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0))
+
     def test_nonparam_udf_with_aggregate(self):
         import pyspark.sql.functions as f
 
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 13faf47..6cb8aec 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -34,8 +34,9 @@ from pyspark.serializers import CloudPickleSerializer
 
 __all__ = [
     "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
-    "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
-    "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]
+    "TimestampType", "TimestampNTZType", "DecimalType", "DoubleType", "FloatType",
+    "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType",
+    "StructField", "StructType"]
 
 
 class DataType(object):
@@ -188,6 +189,29 @@ class TimestampType(AtomicType, metaclass=DataTypeSingleton):
             return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000)
 
 
+class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton):
+    """Timestamp (datetime.datetime) data type without timezone information.
+    """
+
+    def needConversion(self):
+        return True
+
+    @classmethod
+    def typeName(cls):
+        return 'timestamp_ntz'
+
+    def toInternal(self, dt):
+        if dt is not None:
+            seconds = calendar.timegm(dt.timetuple())
+            return int(seconds) * 1000000 + dt.microsecond
+
+    def fromInternal(self, ts):
+        if ts is not None:
+            # using int to avoid precision loss in float
+            return datetime.datetime.utcfromtimestamp(
+                ts // 1000000).replace(microsecond=ts % 1000000)
+
+
 class DecimalType(FractionalType):
     """Decimal (decimal.Decimal) data type.
 
@@ -767,7 +791,8 @@ class UserDefinedType(DataType):
 
 
 _atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
-                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, NullType]
+                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType,
+                 TimestampNTZType, NullType]
 _all_atomic_types = dict((t.typeName(), t) for t in _atomic_types)
 _all_complex_types = dict((v.typeName(), v)
                           for v in [ArrayType, MapType, StructType])
@@ -901,6 +926,8 @@ def _parse_datatype_json_value(json_value):
             return _all_atomic_types[json_value]()
         elif json_value == 'decimal':
             return DecimalType()
+        elif json_value == 'timestamp_ntz':
+            return TimestampNTZType()
         elif _FIXED_DECIMAL.match(json_value):
             m = _FIXED_DECIMAL.match(json_value)
             return DecimalType(int(m.group(1)), int(m.group(2)))
@@ -926,8 +953,8 @@ _type_mappings = {
     bytearray: BinaryType,
     decimal.Decimal: DecimalType,
     datetime.date: DateType,
-    datetime.datetime: TimestampType,
-    datetime.time: TimestampType,
+    datetime.datetime: TimestampType,  # can be TimestampNTZType
+    datetime.time: TimestampType,  # can be TimestampNTZType
     bytes: BinaryType,
 }
 
@@ -1005,7 +1032,7 @@ if sys.version_info[0] < 4:
     _array_type_mappings['u'] = StringType
 
 
-def _infer_type(obj, infer_dict_as_struct=False):
+def _infer_type(obj, infer_dict_as_struct=False, prefer_timestamp_ntz=False):
     """Infer the DataType from obj
     """
     if obj is None:
@@ -1018,6 +1045,8 @@ def _infer_type(obj, infer_dict_as_struct=False):
     if dataType is DecimalType:
         # the precision and scale of `obj` may be different from row to row.
         return DecimalType(38, 18)
+    if dataType is TimestampType and prefer_timestamp_ntz and obj.tzname() is None:
+        return TimestampNTZType()
     elif dataType is not None:
         return dataType()
 
@@ -1026,18 +1055,21 @@ def _infer_type(obj, infer_dict_as_struct=False):
             struct = StructType()
             for key, value in obj.items():
                 if key is not None and value is not None:
-                    struct.add(key, _infer_type(value, infer_dict_as_struct), True)
+                    struct.add(
+                        key, _infer_type(value, infer_dict_as_struct, prefer_timestamp_ntz), True)
             return struct
         else:
             for key, value in obj.items():
                 if key is not None and value is not None:
-                    return MapType(_infer_type(key, infer_dict_as_struct),
-                                   _infer_type(value, infer_dict_as_struct), True)
+                    return MapType(
+                        _infer_type(key, infer_dict_as_struct, prefer_timestamp_ntz),
+                        _infer_type(value, infer_dict_as_struct, prefer_timestamp_ntz), True)
             return MapType(NullType(), NullType(), True)
     elif isinstance(obj, list):
         for v in obj:
             if v is not None:
-                return ArrayType(_infer_type(obj[0], infer_dict_as_struct), True)
+                return ArrayType(
+                    _infer_type(obj[0], infer_dict_as_struct, prefer_timestamp_ntz), True)
         return ArrayType(NullType(), True)
     elif isinstance(obj, array):
         if obj.typecode in _array_type_mappings:
@@ -1051,7 +1083,7 @@ def _infer_type(obj, infer_dict_as_struct=False):
             raise TypeError("not supported type: %s" % type(obj))
 
 
-def _infer_schema(row, names=None, infer_dict_as_struct=False):
+def _infer_schema(row, names=None, infer_dict_as_struct=False, prefer_timestamp_ntz=False):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
@@ -1077,7 +1109,8 @@ def _infer_schema(row, names=None, infer_dict_as_struct=False):
     fields = []
     for k, v in items:
         try:
-            fields.append(StructField(k, _infer_type(v, infer_dict_as_struct), True))
+            fields.append(StructField(
+                k, _infer_type(v, infer_dict_as_struct, prefer_timestamp_ntz), True))
         except TypeError as e:
             raise TypeError("Unable to infer the type of the field {}.".format(k)) from e
     return StructType(fields)
@@ -1107,6 +1140,10 @@ def _merge_type(a, b, name=None):
         return b
     elif isinstance(b, NullType):
         return a
+    elif isinstance(a, TimestampType) and isinstance(b, TimestampNTZType):
+        return a
+    elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType):
+        return b
     elif type(a) is not type(b):
         # TODO: type cast (such as int -> long)
         raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))
@@ -1211,6 +1248,7 @@ _acceptable_types = {
     BinaryType: (bytearray, bytes),
     DateType: (datetime.date, datetime.datetime),
     TimestampType: (datetime.datetime,),
+    TimestampNTZType: (datetime.datetime,),
     ArrayType: (list, tuple, array),
     MapType: (dict,),
     StructType: (tuple, list, dict),
diff --git a/python/pyspark/sql/types.pyi b/python/pyspark/sql/types.pyi
index 3adf823..58c646f 100644
--- a/python/pyspark/sql/types.pyi
+++ b/python/pyspark/sql/types.pyi
@@ -60,6 +60,11 @@ class TimestampType(AtomicType, metaclass=DataTypeSingleton):
     def toInternal(self, dt: datetime.datetime) -> int: ...
     def fromInternal(self, ts: int) -> datetime.datetime: ...
 
+class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton):
+    def needConversion(self) -> bool: ...
+    def toInternal(self, dt: datetime.datetime) -> int: ...
+    def fromInternal(self, ts: int) -> datetime.datetime: ...
+
 class DecimalType(FractionalType):
     precision: int
     scale: int
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 4885f63..fe71319 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -35,7 +35,7 @@ import org.apache.spark.unsafe.types.UTF8String
 object EvaluatePython {
 
   def needConversionInPython(dt: DataType): Boolean = dt match {
-    case DateType | TimestampType => true
+    case DateType | TimestampType | TimestampNTZType => true
     case _: StructType => true
     case _: UserDefinedType[_] => true
     case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -137,7 +137,7 @@ object EvaluatePython {
       case c: Int => c
     }
 
-    case TimestampType => (obj: Any) => nullSafeConvert(obj) {
+    case TimestampType | TimestampNTZType => (obj: Any) => nullSafeConvert(obj) {
       case c: Long => c
       // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
       case c: Int => c.toLong

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
