This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new e34012b1be6 [SPARK-42023][SPARK-42024][CONNECT][PYTHON] Make `createDataFrame` support `AtomicType -> StringType` coercion
e34012b1be6 is described below

commit e34012b1be600b48b18f820c2d8e0836ac0dfc6e
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Jan 31 15:27:28 2023 +0900

    [SPARK-42023][SPARK-42024][CONNECT][PYTHON] Make `createDataFrame` support `AtomicType -> StringType` coercion
    
    ### What changes were proposed in this pull request?
    Make `createDataFrame` support `AtomicType -> StringType` coercion
    
    ### Why are the changes needed?
    This keeps Spark Connect consistent with PySpark, where the same coercion was added in https://github.com/apache/spark/commit/51b04406028e14fbe1986f6a3ffc67853bd82935
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added a unit test and enabled two previously skipped parity tests.
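
    As an illustrative sketch of the resulting behavior (not part of the patch; the sc://localhost endpoint is an assumed local Spark Connect server), mirroring the data in the new test:

        from pyspark.sql import SparkSession

        # Assumed endpoint; any reachable Spark Connect server works.
        spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

        # Mixed bool/str values in one column: inference upcasts column "a"
        # to string, and the bool is rendered lowercase to match PySpark.
        df = spark.createDataFrame([[True, 1], ["false", 1]], ["a", "b"])
        df.show()  # expected rows: ('true', 1), ('false', 1)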
    
    Closes #39818 from zhengruifeng/connect_create_df_corse.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit d0f5f1dc694b1aabe97344c75e15ec061df8758a)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/conversion.py           | 38 +++++++++++++++++++++-
 .../sql/tests/connect/test_connect_basic.py        | 11 +++++++
 .../pyspark/sql/tests/connect/test_parity_types.py | 10 ------
 3 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index db647a9733a..51a6002eb60 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import array
 import datetime
 import decimal
 
@@ -32,6 +33,7 @@ from pyspark.sql.types import (
     BinaryType,
     NullType,
     DecimalType,
+    StringType,
 )
 
 from pyspark.sql.connect.types import to_arrow_schema
@@ -71,6 +73,9 @@ class LocalDataToArrowConversion:
         elif isinstance(dataType, DecimalType):
             # Convert Decimal('NaN') to None
             return True
+        elif isinstance(dataType, StringType):
+            # Coercion to StringType is allowed
+            return True
         else:
             return False
 
@@ -127,7 +132,7 @@ class LocalDataToArrowConversion:
                 if value is None:
                     return None
                 else:
-                    assert isinstance(value, list)
+                    assert isinstance(value, (list, array.array))
                     return [element_conv(v) for v in value]
 
             return convert_array
@@ -184,6 +189,37 @@ class LocalDataToArrowConversion:
 
             return convert_decimal
 
+        elif isinstance(dataType, StringType):
+
+            def convert_string(value: Any) -> Any:
+                if value is None:
+                    return None
+                else:
+                    # only atomic types are supported
+                    assert isinstance(
+                        value,
+                        (
+                            bool,
+                            int,
+                            float,
+                            str,
+                            bytes,
+                            bytearray,
+                            decimal.Decimal,
+                            datetime.date,
+                            datetime.datetime,
+                            datetime.timedelta,
+                        ),
+                    )
+                    if isinstance(value, bool):
+                        # To match PySpark, which converts bool to string on
+                        # the JVM side (python.EvaluatePython.makeFromJava)
+                        return str(value).lower()
+                    else:
+                        return str(value)
+
+            return convert_string
+
         else:
 
             return lambda value: value
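
For reference, a minimal standalone sketch of the coercion logic this hunk adds (simplified from convert_string above; the converter dispatch and assertion set follow the patch):

    import datetime
    import decimal
    from typing import Any

    def convert_string(value: Any) -> Any:
        if value is None:
            return None
        # Only atomic types are supported.
        assert isinstance(
            value,
            (bool, int, float, str, bytes, bytearray,
             decimal.Decimal, datetime.date, datetime.datetime,
             datetime.timedelta),
        )
        if isinstance(value, bool):
            # Lowercase to match JVM-side PySpark
            # (python.EvaluatePython.makeFromJava).
            return str(value).lower()
        return str(value)

    assert convert_string(True) == "true"
    assert convert_string(1.33) == "1.33"
    assert convert_string(None) is None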
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 061591e742c..d51de331f7a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -706,6 +706,17 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(cdf.schema, sdf.schema)
         self.assertEqual(cdf.collect(), sdf.collect())
 
+    def test_create_dataframe_with_coercion(self):
+        data1 = [[1.33, 1], ["2.1", 1]]
+        data2 = [[True, 1], ["false", 1]]
+
+        for data in [data1, data2]:
+            cdf = self.connect.createDataFrame(data, ["a", "b"])
+            sdf = self.spark.createDataFrame(data, ["a", "b"])
+
+            self.assertEqual(cdf.schema, sdf.schema)
+            self.assertEqual(cdf.collect(), sdf.collect())
+
     def test_nested_type_create_from_rows(self):
         data1 = [Row(a=1, b=Row(c=2, d=Row(e=3, f=Row(g=4, h=Row(i=5)))))]
         # root
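
A brief sketch of what the new test's assertions imply (illustrative; follows from the convert_string rules above): for data1, column "a" is inferred as string and 1.33 collects as '1.33'; for data2, True collects as 'true'.

    def coerce(v):
        # Mirror convert_string: bools are lowercased, the rest go through str().
        return str(v).lower() if isinstance(v, bool) else str(v)

    assert [coerce(v) for v in [1.33, "2.1"]] == ["1.33", "2.1"]      # data1
    assert [coerce(v) for v in [True, "false"]] == ["true", "false"]  # data2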
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index d40b3f1a5f2..025e64f2bf0 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -109,16 +109,6 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
     def test_infer_schema_to_local(self):
         super().test_infer_schema_to_local()
 
-    # TODO(SPARK-42023): createDataFrame should corse types of string false to bool false
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_infer_schema_upcast_boolean_to_string(self):
-        super().test_infer_schema_upcast_boolean_to_string()
-
-    # TODO(SPARK-42024): createDataFrame should corse types of string float to float
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_infer_schema_upcast_float_to_string(self):
-        super().test_infer_schema_upcast_float_to_string()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on 
them.")
     def test_infer_schema_upcast_int_to_string(self):
         super().test_infer_schema_upcast_int_to_string()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
