This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 61035129a35  [SPARK-42875][CONNECT][PYTHON] Fix toPandas to handle timezone and map types properly
61035129a35 is described below

commit 61035129a354d0b31c66908106238b12b1f2f7b0
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Tue Mar 21 09:43:11 2023 +0800

    [SPARK-42875][CONNECT][PYTHON] Fix toPandas to handle timezone and map types properly

    ### What changes were proposed in this pull request?

    Fix `DataFrame.toPandas()` to handle timezone and map types properly.

    ### Why are the changes needed?

    Currently `DataFrame.toPandas()` doesn't properly handle the timezone for timestamp types, or map types. For example:

    ```py
    >>> schema = StructType().add("ts", TimestampType())
    >>> spark.createDataFrame([(datetime(1969, 1, 1, 1, 1, 1),), (datetime(2012, 3, 3, 3, 3, 3),), (datetime(2100, 4, 4, 4, 4, 4),)], schema).toPandas()
                             ts
    0 1969-01-01 01:01:01-08:00
    1 2012-03-03 03:03:03-08:00
    2 2100-04-04 03:04:04-08:00
    ```

    which should be:

    ```py
                       ts
    0 1969-01-01 01:01:01
    1 2012-03-03 03:03:03
    2 2100-04-04 04:04:04
    ```

    ### Does this PR introduce _any_ user-facing change?

    The result of `DataFrame.toPandas()` with timestamp and map types will be the same as in PySpark.

    ### How was this patch tested?

    Enabled the related tests.

    Closes #40497 from ueshin/issues/SPARK-42875/timestamp.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/client.py               | 24 ++++++---
 .../sql/tests/connect/test_parity_dataframe.py     | 20 +++----
 python/pyspark/sql/tests/test_dataframe.py         | 61 +++++++++++++---------
 3 files changed, 60 insertions(+), 45 deletions(-)

diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 090d239fbb4..53fa97372a7 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -71,7 +71,8 @@ from pyspark.sql.connect.expressions import (
     CommonInlineUserDefinedFunction,
     JavaUDF,
 )
-from pyspark.sql.types import DataType, StructType
+from pyspark.sql.pandas.types import _check_series_localize_timestamps, _convert_map_items_to_dict
+from pyspark.sql.types import DataType, MapType, StructType, TimestampType
 from pyspark.rdd import PythonEvalType
@@ -637,12 +638,23 @@ class SparkConnectClient(object):
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        table, _, metrics, observed_metrics, _ = self._execute_and_fetch(req)
+        table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req)
         assert table is not None
-        column_names = table.column_names
-        table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
-        pdf = table.to_pandas()
-        pdf.columns = column_names
+        pdf = table.rename_columns([f"col_{i}" for i in range(len(table.column_names))]).to_pandas()
+        pdf.columns = table.column_names
+
+        schema = schema or types.from_arrow_schema(table.schema)
+        assert schema is not None and isinstance(schema, StructType)
+
+        for field, pa_field in zip(schema, table.schema):
+            if isinstance(field.dataType, TimestampType):
+                assert pa_field.type.tz is not None
+                pdf[field.name] = _check_series_localize_timestamps(
+                    pdf[field.name], pa_field.type.tz
+                )
+            elif isinstance(field.dataType, MapType):
+                pdf[field.name] = _convert_map_items_to_dict(pdf[field.name])
+
         if len(metrics) > 0:
             pdf.attrs["metrics"] = metrics
         if len(observed_metrics) > 0:
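For context, the client.py change above post-processes the Arrow-built pandas DataFrame so that tz-aware timestamp columns come back as naive timestamps in the target timezone and map columns come back as Python dicts. Below is a rough, self-contained sketch of those two conversions using plain pandas; the real implementation uses the `_check_series_localize_timestamps` and `_convert_map_items_to_dict` helpers from `pyspark.sql.pandas.types`, and the stand-in functions here only approximate them.

```py
import pandas as pd

def localize_timestamps(series: pd.Series, tz: str) -> pd.Series:
    # Approximation: convert a tz-aware series to the given timezone, then drop
    # the tz info so the values match the naive timestamps regular PySpark returns.
    return series.dt.tz_convert(tz).dt.tz_localize(None)

def map_items_to_dict(series: pd.Series) -> pd.Series:
    # Approximation: Arrow materializes map columns as lists of (key, value)
    # pairs, while PySpark returns Python dicts, so convert each row.
    return series.apply(lambda items: dict(items) if items is not None else None)

ts = pd.Series(pd.to_datetime(["1969-01-01 09:01:01"], utc=True))
print(localize_timestamps(ts, "America/Los_Angeles"))        # 1969-01-01 01:01:01
print(map_items_to_dict(pd.Series([[("a", 1), ("b", 2)]])))   # {'a': 1, 'b': 2}
```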
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 31dee6a19d2..ae812b4ca55 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -22,11 +22,6 @@ from pyspark.testing.connectutils import ReusedConnectTestCase


 class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_create_dataframe_from_pandas_with_dst(self):
-        super().test_create_dataframe_from_pandas_with_dst()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_help_command(self):
         super().test_help_command()
@@ -87,26 +82,25 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_to_local_iterator_prefetch(self):
         super().test_to_local_iterator_prefetch()

-    # TODO(SPARK-41884): DataFrame `toPandas` parity in return types
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_to_pandas(self):
-        super().test_to_pandas()
-
     def test_to_pandas_for_array_of_struct(self):
         # Spark Connect's implementation is based on Arrow.
         super().check_to_pandas_for_array_of_struct(True)

-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_to_pandas_from_null_dataframe(self):
-        super().test_to_pandas_from_null_dataframe()
+        self.check_to_pandas_from_null_dataframe()

     def test_to_pandas_on_cross_join(self):
         self.check_to_pandas_on_cross_join()

+    def test_to_pandas_from_empty_dataframe(self):
+        self.check_to_pandas_from_empty_dataframe()
+
     def test_to_pandas_with_duplicated_column_names(self):
         self.check_to_pandas_with_duplicated_column_names()

+    def test_to_pandas_from_mixed_dataframe(self):
+        self.check_to_pandas_from_mixed_dataframe()
+

 if __name__ == "__main__":
     import unittest
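The parity tests above are re-enabled by delegating to shared `check_*` helpers, and the corresponding tests in test_dataframe.py (below) are split so that each `test_*` wrapper re-runs its `check_*` body with Arrow execution both enabled and disabled. A minimal stand-alone sketch of that pattern follows; the class and helper names are illustrative, not the real test suite.

```py
import unittest

class ToPandasConfMatrixExample(unittest.TestCase):
    def _with_arrow_enabled(self, value: bool) -> None:
        # Illustrative stand-in for
        # self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}).
        self.arrow_enabled = value

    def check_to_pandas_dtypes(self) -> None:
        # The real check_* helpers build a DataFrame and compare toPandas() dtypes;
        # here we only assert that the flag was set for the current run.
        self.assertIn(self.arrow_enabled, (True, False))

    def test_to_pandas_dtypes(self) -> None:
        # Run the same check once per Arrow setting, mirroring the refactored tests.
        for value in (True, False):
            self._with_arrow_enabled(value)
            self.check_to_pandas_dtypes()

if __name__ == "__main__":
    unittest.main()
```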
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index bd2f1cb75b7..cb209f472bf 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1186,6 +1186,12 @@ class DataFrameTestsMixin:

     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     def test_to_pandas_from_empty_dataframe(self):
+        is_arrow_enabled = [True, False]
+        for value in is_arrow_enabled:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
+                self.check_to_pandas_from_empty_dataframe()
+
+    def check_to_pandas_from_empty_dataframe(self):
         # SPARK-29188 test that toPandas() on an empty dataframe has the correct dtypes
         # SPARK-30537 test that toPandas() on an empty dataframe has the correct dtypes
         # when arrow is enabled
@@ -1204,15 +1210,18 @@ class DataFrameTestsMixin:
             CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz,
             INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval
             """
+        dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
+        dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
+        self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df))
+
+    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
+    def test_to_pandas_from_null_dataframe(self):
         is_arrow_enabled = [True, False]
         for value in is_arrow_enabled:
             with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
-                dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
-                dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
-                self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df))
+                self.check_to_pandas_from_null_dataframe()

-    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
-    def test_to_pandas_from_null_dataframe(self):
+    def check_to_pandas_from_null_dataframe(self):
         # SPARK-29188 test that toPandas() on a dataframe with only nulls has correct dtypes
         # SPARK-30537 test that toPandas() on a dataframe with only nulls has correct dtypes
         # using arrow
@@ -1231,25 +1240,28 @@ class DataFrameTestsMixin:
             CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz,
             INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval
             """
+        pdf = self.spark.sql(sql).toPandas()
+        types = pdf.dtypes
+        self.assertEqual(types[0], np.float64)
+        self.assertEqual(types[1], np.float64)
+        self.assertEqual(types[2], np.float64)
+        self.assertEqual(types[3], np.float64)
+        self.assertEqual(types[4], np.float32)
+        self.assertEqual(types[5], np.float64)
+        self.assertEqual(types[6], object)
+        self.assertEqual(types[7], object)
+        self.assertTrue(np.can_cast(np.datetime64, types[8]))
+        self.assertTrue(np.can_cast(np.datetime64, types[9]))
+        self.assertTrue(np.can_cast(np.timedelta64, types[10]))
+
+    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
+    def test_to_pandas_from_mixed_dataframe(self):
         is_arrow_enabled = [True, False]
         for value in is_arrow_enabled:
             with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
-                pdf = self.spark.sql(sql).toPandas()
-                types = pdf.dtypes
-                self.assertEqual(types[0], np.float64)
-                self.assertEqual(types[1], np.float64)
-                self.assertEqual(types[2], np.float64)
-                self.assertEqual(types[3], np.float64)
-                self.assertEqual(types[4], np.float32)
-                self.assertEqual(types[5], np.float64)
-                self.assertEqual(types[6], object)
-                self.assertEqual(types[7], object)
-                self.assertTrue(np.can_cast(np.datetime64, types[8]))
-                self.assertTrue(np.can_cast(np.datetime64, types[9]))
-                self.assertTrue(np.can_cast(np.timedelta64, types[10]))
+                self.check_to_pandas_from_mixed_dataframe()

-    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
-    def test_to_pandas_from_mixed_dataframe(self):
+    def check_to_pandas_from_mixed_dataframe(self):
         # SPARK-29188 test that toPandas() on a dataframe with some nulls has correct dtypes
         # SPARK-30537 test that toPandas() on a dataframe with some nulls has correct dtypes
         # using arrow
@@ -1270,12 +1282,9 @@ class DataFrameTestsMixin:
             FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
                         (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
             """
-        is_arrow_enabled = [True, False]
-        for value in is_arrow_enabled:
-            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
-                pdf_with_some_nulls = self.spark.sql(sql).toPandas()
-                pdf_with_only_nulls = self.spark.sql(sql).filter("tinyint is null").toPandas()
-                self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes))
+        pdf_with_some_nulls = self.spark.sql(sql).toPandas()
+        pdf_with_only_nulls = self.spark.sql(sql).filter("tinyint is null").toPandas()
+        self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes))

     @unittest.skipIf(
         not have_pandas or not have_pyarrow or pyarrow_version_less_than_minimum("2.0.0"),

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org