This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new e34012b1be6 [SPARK-42023][SPARK-42024][CONNECT][PYTHON] Make `createDataFrame` support `AtomicType -> StringType` coercion e34012b1be6 is described below commit e34012b1be600b48b18f820c2d8e0836ac0dfc6e Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Tue Jan 31 15:27:28 2023 +0900 [SPARK-42023][SPARK-42024][CONNECT][PYTHON] Make `createDataFrame` support `AtomicType -> StringType` coercion ### What changes were proposed in this pull request? Make `createDataFrame` support `AtomicType -> StringType` coercion. ### Why are the changes needed? To be consistent with PySpark, where this feature was added in https://github.com/apache/spark/commit/51b04406028e14fbe1986f6a3ffc67853bd82935 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added a new unit test and enabled previously skipped unit tests. Closes #39818 from zhengruifeng/connect_create_df_corse. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit d0f5f1dc694b1aabe97344c75e15ec061df8758a) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/connect/conversion.py | 38 +++++++++++++++++++++- .../sql/tests/connect/test_connect_basic.py | 11 +++++++ .../pyspark/sql/tests/connect/test_parity_types.py | 10 ------ 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py index db647a9733a..51a6002eb60 100644 --- a/python/pyspark/sql/connect/conversion.py +++ b/python/pyspark/sql/connect/conversion.py @@ -15,6 +15,7 @@ # limitations under the License. 
# +import array import datetime import decimal @@ -32,6 +33,7 @@ from pyspark.sql.types import ( BinaryType, NullType, DecimalType, + StringType, ) from pyspark.sql.connect.types import to_arrow_schema @@ -71,6 +73,9 @@ class LocalDataToArrowConversion: elif isinstance(dataType, DecimalType): # Convert Decimal('NaN') to None return True + elif isinstance(dataType, StringType): + # Coercion to StringType is allowed + return True else: return False @@ -127,7 +132,7 @@ class LocalDataToArrowConversion: if value is None: return None else: - assert isinstance(value, list) + assert isinstance(value, (list, array.array)) return [element_conv(v) for v in value] return convert_array @@ -184,6 +189,37 @@ class LocalDataToArrowConversion: return convert_decimal + elif isinstance(dataType, StringType): + + def convert_string(value: Any) -> Any: + if value is None: + return None + else: + # only atomic types are supported + assert isinstance( + value, + ( + bool, + int, + float, + str, + bytes, + bytearray, + decimal.Decimal, + datetime.date, + datetime.datetime, + datetime.timedelta, + ), + ) + if isinstance(value, bool): + # To match the PySpark which convert bool to string in + # the JVM side (python.EvaluatePython.makeFromJava) + return str(value).lower() + else: + return str(value) + + return convert_string + else: return lambda value: value diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 061591e742c..d51de331f7a 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -706,6 +706,17 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): self.assertEqual(cdf.schema, sdf.schema) self.assertEqual(cdf.collect(), sdf.collect()) + def test_create_dataframe_with_coercion(self): + data1 = [[1.33, 1], ["2.1", 1]] + data2 = [[True, 1], ["false", 1]] + + for data in [data1, data2]: + cdf = self.connect.createDataFrame(data, 
["a", "b"]) + sdf = self.spark.createDataFrame(data, ["a", "b"]) + + self.assertEqual(cdf.schema, sdf.schema) + self.assertEqual(cdf.collect(), sdf.collect()) + def test_nested_type_create_from_rows(self): data1 = [Row(a=1, b=Row(c=2, d=Row(e=3, f=Row(g=4, h=Row(i=5)))))] # root diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py index d40b3f1a5f2..025e64f2bf0 100644 --- a/python/pyspark/sql/tests/connect/test_parity_types.py +++ b/python/pyspark/sql/tests/connect/test_parity_types.py @@ -109,16 +109,6 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase): def test_infer_schema_to_local(self): super().test_infer_schema_to_local() - # TODO(SPARK-42023): createDataFrame should corse types of string false to bool false - @unittest.skip("Fails in Spark Connect, should enable.") - def test_infer_schema_upcast_boolean_to_string(self): - super().test_infer_schema_upcast_boolean_to_string() - - # TODO(SPARK-42024): createDataFrame should corse types of string float to float - @unittest.skip("Fails in Spark Connect, should enable.") - def test_infer_schema_upcast_float_to_string(self): - super().test_infer_schema_upcast_float_to_string() - @unittest.skip("Spark Connect does not support RDD but the tests depend on them.") def test_infer_schema_upcast_int_to_string(self): super().test_infer_schema_upcast_int_to_string() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org