This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9c5bcac  [SPARK-36626][PYTHON] Support TimestampNTZ in createDataFrame/toPandas and Python UDFs
9c5bcac is described below

commit 9c5bcac61ee56fbb271e890cc33f9a983612c5b0
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Sep 2 14:00:27 2021 +0900

    [SPARK-36626][PYTHON] Support TimestampNTZ in createDataFrame/toPandas and Python UDFs

    ### What changes were proposed in this pull request?

    This PR proposes to implement `TimestampNTZType` support in PySpark's `SparkSession.createDataFrame`, `DataFrame.toPandas`, Python UDFs, and pandas UDFs with and without Arrow.

    ### Why are the changes needed?

    To complete `TimestampNTZType` support.

    ### Does this PR introduce _any_ user-facing change?

    Yes.

    - Users can now use `TimestampNTZType` in `SparkSession.createDataFrame`, `DataFrame.toPandas`, Python UDFs, and pandas UDFs with and without Arrow.
    - If `spark.sql.timestampType` is configured to `TIMESTAMP_NTZ`, PySpark infers a `datetime` without a timezone as `TimestampNTZType`; a `datetime` with a timezone is still inferred as `TimestampType` in `SparkSession.createDataFrame`.
    - If `TimestampType` and `TimestampNTZType` conflict while merging the inferred schema, `TimestampType` takes precedence.
    - A `TimestampNTZType` value is treated internally as having an unknown timezone, is computed in UTC (matching the JVM side), and is not localized externally.

    ### How was this patch tested?

    Manually tested, and unit tests were added.

    Closes #33876 from HyukjinKwon/SPARK-36626.

    Lead-authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Co-authored-by: Dominik Gehl <d...@open.ch>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/docs/source/reference/pyspark.sql.rst  |  1 +
 python/pyspark/sql/pandas/conversion.py       | 17 ++++--
 python/pyspark/sql/pandas/conversion.pyi      |  1 +
 python/pyspark/sql/pandas/serializers.py      |  4 +-
 python/pyspark/sql/pandas/types.py            |  8 ++-
 python/pyspark/sql/session.py                 | 20 +++++--
 python/pyspark/sql/tests/test_arrow.py        | 16 +++++-
 python/pyspark/sql/tests/test_dataframe.py    | 34 ++++++++----
 python/pyspark/sql/tests/test_pandas_udf.py   | 18 +++++++
 python/pyspark/sql/tests/test_types.py        | 16 +++++-
 python/pyspark/sql/tests/test_udf.py          | 20 ++++++-
 python/pyspark/sql/types.py                   | 62 +++++++++++++++++-----
 python/pyspark/sql/types.pyi                  |  5 ++
 .../sql/execution/python/EvaluatePython.scala |  4 +-
 14 files changed, 185 insertions(+), 41 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst
index d72332a..605a150 100644
--- a/python/docs/source/reference/pyspark.sql.rst
+++ b/python/docs/source/reference/pyspark.sql.rst
@@ -298,6 +298,7 @@ Data Types
     StringType
     StructField
     StructType
+    TimestampNTZType
     TimestampType


diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 92ef7ce..5454410 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -22,7 +22,7 @@ from pyspark.rdd import _load_from_socket
 from pyspark.sql.pandas.serializers import ArrowCollectSerializer
 from pyspark.sql.types import IntegralType
 from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
-    DoubleType, BooleanType, MapType, TimestampType, StructType, DataType
+    DoubleType, BooleanType, MapType, TimestampType, TimestampNTZType, StructType, DataType
 from pyspark.traceback_utils import SCCallSiteSync


@@ -238,6 +238,8 @@ class PandasConversionMixin(object):
             return np.bool
         elif type(dt) == TimestampType:
             return np.datetime64
+        elif type(dt) == TimestampNTZType:
+            return np.datetime64
         else:
             return None

@@ -354,6 +356,8 @@ class SparkConversionMixin(object):

         if timezone is not None:
             from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
+            from pandas.core.dtypes.common import is_datetime64tz_dtype
+
             copied = False
             if isinstance(schema, StructType):
                 for field in schema:
@@ -368,8 +372,11 @@ class SparkConversionMixin(object):
                                 copied = True
                             pdf[field.name] = s
             else:
+                should_localize = not self._is_timestamp_ntz_preferred()
                 for column, series in pdf.iteritems():
-                    s = _check_series_convert_timestamps_tz_local(series, timezone)
+                    s = series
+                    if should_localize and is_datetime64tz_dtype(s.dtype) and s.dt.tz is not None:
+                        s = _check_series_convert_timestamps_tz_local(series, timezone)
                     if s is not series:
                         if not copied:
                             # Copy once if the series is modified to prevent the original
@@ -448,8 +455,12 @@ class SparkConversionMixin(object):
         if isinstance(schema, (list, tuple)):
             arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
             struct = StructType()
+            prefer_timestamp_ntz = self._is_timestamp_ntz_preferred()
             for name, field in zip(schema, arrow_schema):
-                struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
+                struct.add(
+                    name,
+                    from_arrow_type(field.type, prefer_timestamp_ntz),
+                    nullable=field.nullable)
             schema = struct

         # Determine arrow types to coerce data when creating batches
diff --git a/python/pyspark/sql/pandas/conversion.pyi b/python/pyspark/sql/pandas/conversion.pyi
index 031852f..87637722 100644
--- a/python/pyspark/sql/pandas/conversion.pyi
+++ b/python/pyspark/sql/pandas/conversion.pyi
@@ -38,6 +38,7 @@ from pyspark.sql.types import (  # noqa: F401
     ShortType as ShortType,
     StructType as StructType,
     TimestampType as TimestampType,
+    TimestampNTZType as TimestampNTZType,
 )
 from pyspark.traceback_utils import SCCallSiteSync as SCCallSiteSync  # noqa: F401
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 2dcfdc1..268b766 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -126,7 +126,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         # datetime64[ns] type handling.
         s = arrow_column.to_pandas(date_as_object=True)

-        if pyarrow.types.is_timestamp(arrow_column.type):
+        if pyarrow.types.is_timestamp(arrow_column.type) and arrow_column.type.tz is not None:
             return _check_series_localize_timestamps(s, self._timezone)
         elif pyarrow.types.is_map(arrow_column.type):
             return _convert_map_items_to_dict(s)
@@ -162,7 +162,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         def create_array(s, t):
             mask = s.isnull()
             # Ensure timestamp series are in expected form for Spark internal representation
-            if t is not None and pa.types.is_timestamp(t):
+            if t is not None and pa.types.is_timestamp(t) and t.tz is not None:
                 s = _check_series_convert_timestamps_internal(s, self._timezone)
             elif t is not None and pa.types.is_map(t):
                 s = _convert_dict_to_map_items(s)
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 489b466..ceb71a3 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -22,7 +22,7 @@ pandas instances during the type conversion.
 from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType, LongType, \
     FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, \
-    ArrayType, MapType, StructType, StructField, NullType
+    TimestampNTZType, ArrayType, MapType, StructType, StructField, NullType


 def to_arrow_type(dt):
@@ -55,6 +55,8 @@ def to_arrow_type(dt):
     elif type(dt) == TimestampType:
         # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
         arrow_type = pa.timestamp('us', tz='UTC')
+    elif type(dt) == TimestampNTZType:
+        arrow_type = pa.timestamp('us', tz=None)
     elif type(dt) == ArrayType:
         if type(dt.elementType) in [StructType, TimestampType]:
             raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
@@ -88,7 +90,7 @@ def to_arrow_schema(schema):
     return pa.schema(fields)


-def from_arrow_type(at):
+def from_arrow_type(at, prefer_timestamp_ntz=False):
     """ Convert pyarrow type to Spark data type.
     """
     from distutils.version import LooseVersion
@@ -116,6 +118,8 @@ def from_arrow_type(at):
         spark_type = BinaryType()
     elif types.is_date32(at):
         spark_type = DateType()
+    elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
+        spark_type = TimestampNTZType()
     elif types.is_timestamp(at):
         spark_type = TimestampType()
     elif types.is_list(at):
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 28ae480..12ec649 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -419,6 +419,9 @@ class SparkSession(SparkConversionMixin):

         return DataFrame(jdf, self._wrapped)

+    def _is_timestamp_ntz_preferred(self):
+        return self._wrapped._conf.timestampType().typeName() == "timestamp_ntz"
+
     def _inferSchemaFromList(self, data, names=None):
         """
         Infer schema from list of Row, dict, or tuple.
@@ -437,7 +440,9 @@ class SparkSession(SparkConversionMixin):
         if not data:
             raise ValueError("can not infer schema from empty dataset")
         infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
-        schema = reduce(_merge_type, (_infer_schema(row, names, infer_dict_as_struct)
+        prefer_timestamp_ntz = self._is_timestamp_ntz_preferred()
+        schema = reduce(_merge_type, (
+            _infer_schema(row, names, infer_dict_as_struct, prefer_timestamp_ntz)
                                       for row in data))
         if _has_nulltype(schema):
             raise ValueError("Some of types cannot be determined after inferring")
@@ -465,12 +470,18 @@ class SparkSession(SparkConversionMixin):
                              "can not infer schema")

         infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
+        prefer_timestamp_ntz = self._is_timestamp_ntz_preferred()
         if samplingRatio is None:
-            schema = _infer_schema(first, names=names, infer_dict_as_struct=infer_dict_as_struct)
+            schema = _infer_schema(
+                first,
+                names=names,
+                infer_dict_as_struct=infer_dict_as_struct,
+                prefer_timestamp_ntz=prefer_timestamp_ntz)
             if _has_nulltype(schema):
                 for row in rdd.take(100)[1:]:
                     schema = _merge_type(schema, _infer_schema(
-                        row, names=names, infer_dict_as_struct=infer_dict_as_struct))
+                        row, names=names, infer_dict_as_struct=infer_dict_as_struct,
+                        prefer_timestamp_ntz=prefer_timestamp_ntz))
                     if not _has_nulltype(schema):
                         break
                 else:
@@ -480,7 +491,8 @@ class SparkSession(SparkConversionMixin):
             if samplingRatio < 0.99:
                 rdd = rdd.sample(False, float(samplingRatio))
             schema = rdd.map(lambda row: _infer_schema(
-                row, names, infer_dict_as_struct=infer_dict_as_struct)).reduce(_merge_type)
+                row, names, infer_dict_as_struct=infer_dict_as_struct,
+                prefer_timestamp_ntz=prefer_timestamp_ntz)).reduce(_merge_type)
         return schema

     def _createFromRDD(self, rdd, schema, samplingRatio):
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index cca9ec4..e7fb590 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -27,8 +27,8 @@ from pyspark import SparkContext, SparkConf
 from pyspark.sql import Row, SparkSession
 from pyspark.sql.functions import rand, udf
 from pyspark.sql.types import StructType, StringType, IntegerType, LongType, \
-    FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, \
-    ArrayType, NullType
+    FloatType, DoubleType, DecimalType, DateType, TimestampType, TimestampNTZType, \
+    BinaryType, StructField, ArrayType, NullType
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
@@ -167,6 +167,18 @@ class ArrowTests(ReusedSQLTestCase):
         assert_frame_equal(expected, pdf)
         assert_frame_equal(expected, pdf_arrow)

+    def test_create_data_frame_to_pandas_timestamp_ntz(self):
+        # SPARK-36626: Test TimestampNTZ in createDataFrame and toPandas
+        with self.sql_conf({"spark.sql.session.timeZone": "America/Los_Angeles"}):
+            origin = pd.DataFrame({"a": [datetime.datetime(2012, 2, 2, 2, 2, 2)]})
+            df = self.spark.createDataFrame(
+                origin, schema=StructType([StructField("a", TimestampNTZType(), True)]))
+            df.selectExpr("assert_true('2012-02-02 02:02:02' == CAST(a AS STRING))").collect()
+
+            pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+            assert_frame_equal(origin, pdf)
+            assert_frame_equal(pdf, pdf_arrow)
+
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)

diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 4c38b27..32a0e65 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -25,7 +25,7 @@ import unittest
 from pyspark.sql import SparkSession, Row
 from pyspark.sql.functions import col, lit, count, sum, mean
 from pyspark.sql.types import StringType, IntegerType, DoubleType, StructType, StructField, \
-    BooleanType, DateType, TimestampType, FloatType
+    BooleanType, DateType, TimestampType, TimestampNTZType, FloatType
 from pyspark.sql.utils import AnalysisException, IllegalArgumentException
 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils, have_pyarrow, have_pandas, \
     pandas_requirement_message, pyarrow_requirement_message
@@ -575,12 +575,16 @@ class DataFrameTests(ReusedSQLTestCase):
         from datetime import datetime, date
         schema = StructType().add("a", IntegerType()).add("b", StringType())\
                              .add("c", BooleanType()).add("d", FloatType())\
-                             .add("dt", DateType()).add("ts", TimestampType())
+                             .add("dt", DateType()).add("ts", TimestampType())\
+                             .add("ts_ntz", TimestampNTZType())
         data = [
-            (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
-            (2, "foo", True, 5.0, None, None),
-            (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
-            (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
+            (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1),
+             datetime(1969, 1, 1, 1, 1, 1)),
+            (2, "foo", True, 5.0, None, None, None),
+            (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3),
+             datetime(2012, 3, 3, 3, 3, 3)),
+            (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4),
+             datetime(2100, 4, 4, 4, 4, 4)),
         ]
         df = self.spark.createDataFrame(data, schema)
         return df.toPandas()
@@ -596,6 +600,7 @@ class DataFrameTests(ReusedSQLTestCase):
         self.assertEqual(types[3], np.float32)
         self.assertEqual(types[4], np.object)  # datetime.date
         self.assertEqual(types[5], 'datetime64[ns]')
+        self.assertEqual(types[6], 'datetime64[ns]')

     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     def test_to_pandas_with_duplicated_column_names(self):
@@ -662,7 +667,8 @@ class DataFrameTests(ReusedSQLTestCase):
             CAST(0 AS DOUBLE) AS double,
             CAST(1 AS BOOLEAN) AS boolean,
             CAST('foo' AS STRING) AS string,
-            CAST('2019-01-01' AS TIMESTAMP) AS timestamp
+            CAST('2019-01-01' AS TIMESTAMP) AS timestamp,
+            CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz
             """
         dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
         dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
@@ -682,7 +688,8 @@ class DataFrameTests(ReusedSQLTestCase):
             CAST(NULL AS DOUBLE) AS double,
             CAST(NULL AS BOOLEAN) AS boolean,
             CAST(NULL AS STRING) AS string,
-            CAST(NULL AS TIMESTAMP) AS timestamp
+            CAST(NULL AS TIMESTAMP) AS timestamp,
+            CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz
             """
         pdf = self.spark.sql(sql).toPandas()
         types = pdf.dtypes
@@ -695,6 +702,7 @@ class DataFrameTests(ReusedSQLTestCase):
         self.assertEqual(types[6], np.object)
         self.assertEqual(types[7], np.object)
         self.assertTrue(np.can_cast(np.datetime64, types[8]))
+        self.assertTrue(np.can_cast(np.datetime64, types[9]))

     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     def test_to_pandas_from_mixed_dataframe(self):
@@ -710,9 +718,10 @@ class DataFrameTests(ReusedSQLTestCase):
             CAST(col6 AS DOUBLE) AS double,
             CAST(col7 AS BOOLEAN) AS boolean,
             CAST(col8 AS STRING) AS string,
-            timestamp_seconds(col9) AS timestamp
-            FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1),
-                        (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
+            timestamp_seconds(col9) AS timestamp,
+            timestamp_seconds(col10) AS timestamp_ntz
+            FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
+                        (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
             """
         pdf_with_some_nulls = self.spark.sql(sql).toPandas()
         pdf_with_only_nulls = self.spark.sql(sql).filter('tinyint is null').toPandas()
@@ -738,6 +747,9 @@ class DataFrameTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
         self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
         self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
+        df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp_ntz")
+        self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampNTZType))
+        self.assertTrue(isinstance(df.schema['d'].dataType, DateType))

     @unittest.skipIf(have_pandas, "Required Pandas was found.")
     def test_create_dataframe_required_pandas_not_found(self):
diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py
index 975eb468..9ebc943 100644
--- a/python/pyspark/sql/tests/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/test_pandas_udf.py
@@ -16,6 +16,7 @@
 #

 import unittest
+import datetime

 from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
 from pyspark.sql.types import DoubleType, StructType, StructField, LongType
@@ -239,6 +240,23 @@ class PandasUDFTests(ReusedSQLTestCase):
             with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
                 df.withColumn('udf', udf('id')).collect()

+    def test_pandas_udf_timestamp_ntz(self):
+        # SPARK-36626: Test TimestampNTZ in pandas UDF
+        @pandas_udf(returnType="timestamp_ntz")
+        def noop(s):
+            assert s.iloc[0] == datetime.datetime(1970, 1, 1, 0, 0)
+            return s
+
+        with self.sql_conf({"spark.sql.session.timeZone": "Asia/Hong_Kong"}):
+            df = (self.spark
+                  .createDataFrame(
+                      [(datetime.datetime(1970, 1, 1, 0, 0),)], schema="dt timestamp_ntz")
+                  .select(noop("dt").alias("dt")))
+
+            df.selectExpr("assert_true('1970-01-01 00:00:00' == CAST(dt AS STRING))").collect()
+            self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz")
+            self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0))
+
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 8bdc837..1dbddf7 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -29,8 +29,8 @@ from pyspark.sql.functions import col
 from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.utils import AnalysisException
 from pyspark.sql.types import ByteType, ShortType, IntegerType, FloatType, DateType, \
-    TimestampType, MapType, StringType, StructType, StructField, ArrayType, DoubleType, LongType, \
-    DecimalType, BinaryType, BooleanType, NullType
+    TimestampType, MapType, StringType, StructType, StructField,\
+    ArrayType, DoubleType, LongType, DecimalType, BinaryType, BooleanType, NullType
 from pyspark.sql.types import (  # type: ignore
     _array_signed_int_typecode_ctype_mappings, _array_type_mappings,
     _array_unsigned_int_typecode_ctype_mappings, _infer_type, _make_type_verifier, _merge_type
@@ -175,6 +175,18 @@ class TypesTests(ReusedSQLTestCase):
         ]
         self.assertEqual(actual, expected)

+        with self.sql_conf({"spark.sql.timestampType": "TIMESTAMP_NTZ"}):
self.sql_conf({"spark.sql.session.timeZone": "America/Sao_Paulo"}): + df = self.spark.createDataFrame([(datetime.datetime(1970, 1, 1, 0, 0),)]) + self.assertEqual(list(df.schema)[0].dataType.simpleString(), "timestamp_ntz") + self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0)) + + df = self.spark.createDataFrame([ + (datetime.datetime(1970, 1, 1, 0, 0),), + (datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc),) + ]) + self.assertEqual(list(df.schema)[0].dataType.simpleString(), "timestamp") + def test_infer_schema_not_enough_names(self): df = self.spark.createDataFrame([["a", "b"]], ["col1"]) self.assertEqual(df.columns, ['col1', '_2']) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 0d13361..98d193f 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -20,13 +20,14 @@ import pydoc import shutil import tempfile import unittest +import datetime from pyspark import SparkContext from pyspark.sql import SparkSession, Column, Row from pyspark.sql.functions import udf from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import StringType, IntegerType, BooleanType, DoubleType, LongType, \ - ArrayType, StructType, StructField + ArrayType, StructType, StructField, TimestampNTZType from pyspark.sql.utils import AnalysisException from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message from pyspark.testing.utils import QuietTest @@ -552,6 +553,23 @@ class UDFTests(ReusedSQLTestCase): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def test_udf_timestamp_ntz(self): + # SPARK-36626: Test TimestampNTZ in Python UDF + @udf(TimestampNTZType()) + def noop(x): + assert x == datetime.datetime(1970, 1, 1, 0, 0) + return x + + with self.sql_conf({"spark.sql.session.timeZone": "Pacific/Honolulu"}): + df = (self.spark + .createDataFrame( + [(datetime.datetime(1970, 1, 1, 0, 0),)], schema="dt timestamp_ntz") + .select(noop("dt").alias("dt"))) + + df.selectExpr("assert_true('1970-01-01 00:00:00' == CAST(dt AS STRING))").collect() + self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz") + self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0)) + def test_nonparam_udf_with_aggregate(self): import pyspark.sql.functions as f diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 13faf47..6cb8aec 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -34,8 +34,9 @@ from pyspark.serializers import CloudPickleSerializer __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", - "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", - "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"] + "TimestampType", "TimestampNTZType", "DecimalType", "DoubleType", "FloatType", + "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", + "StructField", "StructType"] class DataType(object): @@ -188,6 +189,29 @@ class TimestampType(AtomicType, metaclass=DataTypeSingleton): return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) +class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton): + """Timestamp (datetime.datetime) data type without timezone information. 
+ """ + + def needConversion(self): + return True + + @classmethod + def typeName(cls): + return 'timestamp_ntz' + + def toInternal(self, dt): + if dt is not None: + seconds = calendar.timegm(dt.timetuple()) + return int(seconds) * 1000000 + dt.microsecond + + def fromInternal(self, ts): + if ts is not None: + # using int to avoid precision loss in float + return datetime.datetime.utcfromtimestamp( + ts // 1000000).replace(microsecond=ts % 1000000) + + class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. @@ -767,7 +791,8 @@ class UserDefinedType(DataType): _atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, - ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, NullType] + ByteType, ShortType, IntegerType, LongType, DateType, TimestampType, + TimestampNTZType, NullType] _all_atomic_types = dict((t.typeName(), t) for t in _atomic_types) _all_complex_types = dict((v.typeName(), v) for v in [ArrayType, MapType, StructType]) @@ -901,6 +926,8 @@ def _parse_datatype_json_value(json_value): return _all_atomic_types[json_value]() elif json_value == 'decimal': return DecimalType() + elif json_value == 'timestamp_ntz': + return TimestampNTZType() elif _FIXED_DECIMAL.match(json_value): m = _FIXED_DECIMAL.match(json_value) return DecimalType(int(m.group(1)), int(m.group(2))) @@ -926,8 +953,8 @@ _type_mappings = { bytearray: BinaryType, decimal.Decimal: DecimalType, datetime.date: DateType, - datetime.datetime: TimestampType, - datetime.time: TimestampType, + datetime.datetime: TimestampType, # can be TimestampNTZType + datetime.time: TimestampType, # can be TimestampNTZType bytes: BinaryType, } @@ -1005,7 +1032,7 @@ if sys.version_info[0] < 4: _array_type_mappings['u'] = StringType -def _infer_type(obj, infer_dict_as_struct=False): +def _infer_type(obj, infer_dict_as_struct=False, prefer_timestamp_ntz=False): """Infer the DataType from obj """ if obj is None: @@ -1018,6 +1045,8 @@ def _infer_type(obj, infer_dict_as_struct=False): if dataType is DecimalType: # the precision and scale of `obj` may be different from row to row. 
         return DecimalType(38, 18)
+    if dataType is TimestampType and prefer_timestamp_ntz and obj.tzname() is None:
+        return TimestampNTZType()
     elif dataType is not None:
         return dataType()

@@ -1026,18 +1055,21 @@ def _infer_type(obj, infer_dict_as_struct=False):
             struct = StructType()
             for key, value in obj.items():
                 if key is not None and value is not None:
-                    struct.add(key, _infer_type(value, infer_dict_as_struct), True)
+                    struct.add(
+                        key, _infer_type(value, infer_dict_as_struct, prefer_timestamp_ntz), True)
             return struct
         else:
             for key, value in obj.items():
                 if key is not None and value is not None:
-                    return MapType(_infer_type(key, infer_dict_as_struct),
-                                   _infer_type(value, infer_dict_as_struct), True)
+                    return MapType(
+                        _infer_type(key, infer_dict_as_struct, prefer_timestamp_ntz),
+                        _infer_type(value, infer_dict_as_struct, prefer_timestamp_ntz), True)
             return MapType(NullType(), NullType(), True)
     elif isinstance(obj, list):
         for v in obj:
             if v is not None:
-                return ArrayType(_infer_type(obj[0], infer_dict_as_struct), True)
+                return ArrayType(
+                    _infer_type(obj[0], infer_dict_as_struct, prefer_timestamp_ntz), True)
         return ArrayType(NullType(), True)
     elif isinstance(obj, array):
         if obj.typecode in _array_type_mappings:
@@ -1051,7 +1083,7 @@ def _infer_type(obj, infer_dict_as_struct=False):
     raise TypeError("not supported type: %s" % type(obj))


-def _infer_schema(row, names=None, infer_dict_as_struct=False):
+def _infer_schema(row, names=None, infer_dict_as_struct=False, prefer_timestamp_ntz=False):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
@@ -1077,7 +1109,8 @@ def _infer_schema(row, names=None, infer_dict_as_struct=False):
     fields = []
     for k, v in items:
         try:
-            fields.append(StructField(k, _infer_type(v, infer_dict_as_struct), True))
+            fields.append(StructField(
+                k, _infer_type(v, infer_dict_as_struct, prefer_timestamp_ntz), True))
         except TypeError as e:
             raise TypeError("Unable to infer the type of the field {}.".format(k)) from e
     return StructType(fields)
@@ -1107,6 +1140,10 @@ def _merge_type(a, b, name=None):
         return b
     elif isinstance(b, NullType):
         return a
+    elif isinstance(a, TimestampType) and isinstance(b, TimestampNTZType):
+        return a
+    elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType):
+        return b
     elif type(a) is not type(b):
         # TODO: type cast (such as int -> long)
         raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))
@@ -1211,6 +1248,7 @@ _acceptable_types = {
     BinaryType: (bytearray, bytes),
     DateType: (datetime.date, datetime.datetime),
     TimestampType: (datetime.datetime,),
+    TimestampNTZType: (datetime.datetime,),
     ArrayType: (list, tuple, array),
     MapType: (dict,),
     StructType: (tuple, list, dict),
diff --git a/python/pyspark/sql/types.pyi b/python/pyspark/sql/types.pyi
index 3adf823..58c646f 100644
--- a/python/pyspark/sql/types.pyi
+++ b/python/pyspark/sql/types.pyi
@@ -60,6 +60,11 @@ class TimestampType(AtomicType, metaclass=DataTypeSingleton):
     def toInternal(self, dt: datetime.datetime) -> int: ...
     def fromInternal(self, ts: int) -> datetime.datetime: ...

+class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton):
+    def needConversion(self) -> bool: ...
+    def toInternal(self, dt: datetime.datetime) -> int: ...
+    def fromInternal(self, ts: int) -> datetime.datetime: ...
+
 class DecimalType(FractionalType):
     precision: int
     scale: int
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 4885f63..fe71319 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -35,7 +35,7 @@ import org.apache.spark.unsafe.types.UTF8String
 object EvaluatePython {

   def needConversionInPython(dt: DataType): Boolean = dt match {
-    case DateType | TimestampType => true
+    case DateType | TimestampType | TimestampNTZType => true
     case _: StructType => true
     case _: UserDefinedType[_] => true
     case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -137,7 +137,7 @@ object EvaluatePython {
       case c: Int => c
     }

-    case TimestampType => (obj: Any) => nullSafeConvert(obj) {
+    case TimestampType | TimestampNTZType => (obj: Any) => nullSafeConvert(obj) {
       case c: Long => c
       // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
       case c: Int => c.toLong

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
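For context, the user-facing behavior described in the PR description can be sketched as follows. This snippet is not part of the patch; it assumes a PySpark build that includes this change (3.3.0 or later), pandas installed for toPandas(), and uses an illustrative column name.

# Illustrative only; not part of the commit. Assumes PySpark with SPARK-36626
# applied (3.3.0+) and pandas available for toPandas().
import datetime

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import TimestampNTZType

spark = SparkSession.builder.getOrCreate()

# With spark.sql.timestampType set to TIMESTAMP_NTZ, a naive datetime is
# inferred as TimestampNTZType (a timezone-aware datetime still becomes
# TimestampType).
spark.conf.set("spark.sql.timestampType", "TIMESTAMP_NTZ")
df = spark.createDataFrame([(datetime.datetime(1970, 1, 1, 0, 0),)], ["ts"])
assert isinstance(df.schema["ts"].dataType, TimestampNTZType)

# toPandas() keeps the wall-clock value; no session-time-zone localization.
print(df.toPandas())

# TimestampNTZType also works as a Python UDF return type.
@udf(TimestampNTZType())
def passthrough(ts):
    return ts

df.select(passthrough("ts").alias("ts")).show()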