This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 514449b7cbf [SPARK-41899][CONNECT][PYTHON] `createDataFrame` should respect user-provided DDL schema
514449b7cbf is described below

commit 514449b7cbfca253773997fdd173dd138fbd2bf4
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sun Jan 8 13:02:45 2023 +0800

    [SPARK-41899][CONNECT][PYTHON] `createDataFrame` should respect user-provided DDL schema
    
    ### What changes were proposed in this pull request?
    Make `createDataFrame` respect the user-provided DDL schema
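
    A minimal sketch of the intended behavior (illustrative only, not part
    of the patch; assumes a Spark Connect session bound to `spark`):

        # Without a schema, a Python int is inferred as bigint; with a
        # DDL string, the declared integer type should be kept.
        df = spark.createDataFrame([(1, 2)], "a integer, b integer")
        assert df.schema["a"].dataType.simpleString() == "int"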
    
    ### Why are the changes needed?
    Consistency: Spark Connect's `createDataFrame` should apply a
    user-provided DDL schema the same way vanilla PySpark's
    `SparkSession.createDataFrame` does, rather than falling back to the
    schema inferred from the local data.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. When local data is passed together with a DDL schema string,
    `createDataFrame` now produces a DataFrame with the declared types
    (e.g. `integer` stays `integer`) instead of the inferred ones, which
    widen a Python `int` to `bigint`.
    
    ### How was this patch tested?
    Added a unit test (`test_cast_with_ddl`) and re-enabled the previously
    skipped `test_date_add_function` and `test_date_sub_function` parity
    tests.
    
    Closes #39452 from zhengruifeng/connect_fix_41899.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/session.py              | 42 ++++++++++++----------
 .../sql/tests/connect/test_connect_basic.py        | 10 ++++++
 .../sql/tests/connect/test_parity_functions.py     | 10 ------
 3 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 33ec254fe43..bd6e8bd19f3 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -184,7 +184,6 @@ class SparkSession:
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
 
-        table: Optional[pa.Table] = None
         _schema: Optional[Union[AtomicType, StructType]] = None
         _schema_str: Optional[str] = None
         _cols: Optional[List[str]] = None
@@ -207,8 +206,11 @@ class SparkSession:
             else:
                 raise ValueError("can not infer schema from empty dataset")
 
+        _table: Optional[pa.Table] = None
+        _inferred_schema: Optional[StructType] = None
+
         if isinstance(data, pd.DataFrame):
-            table = pa.Table.from_pandas(data)
+            _table = pa.Table.from_pandas(data)
 
         elif isinstance(data, np.ndarray):
             if data.ndim not in [1, 2]:
@@ -227,7 +229,7 @@ class SparkSession:
                         f"new values have {len(_cols)} elements"
                     )
 
-                table = pa.Table.from_arrays([pa.array(data)], _cols)
+                _table = pa.Table.from_arrays([pa.array(data)], _cols)
             else:
                 if data.shape[1] != len(_cols):
                     raise ValueError(
@@ -235,7 +237,7 @@ class SparkSession:
                         f"new values have {len(_cols)} elements"
                     )
 
-                table = pa.Table.from_arrays(
+                _table = pa.Table.from_arrays(
                     [pa.array(data[::, i]) for i in range(0, data.shape[1])], _cols
                 )
 
@@ -248,35 +250,37 @@ class SparkSession:
                     # For dictionaries, we sort the schema in alphabetical order.
                     _data = [dict(sorted(d.items())) for d in _data]
 
-                _schema = self._inferSchemaFromList(_data, _cols)
+                _inferred_schema = self._inferSchemaFromList(_data, _cols)
                 if _cols is not None:
                     for i, name in enumerate(_cols):
-                        _schema.fields[i].name = name
-                        _schema.names[i] = name
+                        _inferred_schema.fields[i].name = name
+                        _inferred_schema.names[i] = name
 
             if _cols is None:
-                if _schema is None:
+                if _schema is None and _inferred_schema is None:
                     if isinstance(_data[0], (list, tuple)):
                         _cols = ["_%s" % i for i in range(1, len(_data[0]) + 
1)]
                     else:
                         _cols = ["_1"]
-                elif isinstance(_schema, StructType):
+                elif _schema is not None and isinstance(_schema, StructType):
                     _cols = _schema.names
+                elif _inferred_schema is not None:
+                    _cols = _inferred_schema.names
                 else:
                     _cols = ["value"]
 
             if isinstance(_data[0], Row):
-                table = pa.Table.from_pylist([row.asDict(recursive=True) for row in _data])
+                _table = pa.Table.from_pylist([row.asDict(recursive=True) for row in _data])
             elif isinstance(_data[0], dict):
-                table = pa.Table.from_pylist(_data)
+                _table = pa.Table.from_pylist(_data)
             elif isinstance(_data[0], (list, tuple)):
-                table = pa.Table.from_pylist([dict(zip(_cols, list(item))) for item in _data])
+                _table = pa.Table.from_pylist([dict(zip(_cols, list(item))) for item in _data])
             else:
                 # input data can be [1, 2, 3]
-                table = pa.Table.from_pylist([dict(zip(_cols, [item])) for item in _data])
+                _table = pa.Table.from_pylist([dict(zip(_cols, [item])) for item in _data])
 
         # Validate number of columns
-        num_cols = table.shape[1]
+        num_cols = _table.shape[1]
         if (
             _schema is not None
             and isinstance(_schema, StructType)
@@ -294,13 +298,15 @@ class SparkSession:
             )
 
         if _schema is not None:
-            return DataFrame.withPlan(LocalRelation(table, schema=_schema.json()), self)
+            return DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
         elif _schema_str is not None:
-            return DataFrame.withPlan(LocalRelation(table, schema=_schema_str), self)
+            return DataFrame.withPlan(LocalRelation(_table, schema=_schema_str), self)
+        elif _inferred_schema is not None:
+            return DataFrame.withPlan(LocalRelation(_table, schema=_inferred_schema.json()), self)
         elif _cols is not None and len(_cols) > 0:
-            return DataFrame.withPlan(LocalRelation(table), self).toDF(*_cols)
+            return DataFrame.withPlan(LocalRelation(_table), self).toDF(*_cols)
         else:
-            return DataFrame.withPlan(LocalRelation(table), self)
+            return DataFrame.withPlan(LocalRelation(_table), self)
 
     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
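
For readers skimming the patch: with the new `_inferred_schema` variable,
the final plan construction follows this precedence (a condensed sketch
that mirrors the patched code above, not a verbatim excerpt):

    if _schema is not None:             # user passed a StructType/AtomicType
        plan = LocalRelation(_table, schema=_schema.json())
    elif _schema_str is not None:       # user passed a DDL string
        plan = LocalRelation(_table, schema=_schema_str)
    elif _inferred_schema is not None:  # schema inferred from the local data
        plan = LocalRelation(_table, schema=_inferred_schema.json())
    elif _cols is not None and len(_cols) > 0:
        plan = LocalRelation(_table)    # columns then renamed via toDF(*_cols)
    else:
        plan = LocalRelation(_table)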
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 235bd2815ae..3f82fdb4f4d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import datetime
 import unittest
 import shutil
 import tempfile
@@ -563,6 +565,14 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             sdf.select(SF.pmod("a", "b")).toPandas(),
         )
 
+    def test_cast_with_ddl(self):
+        data = [Row(date=datetime.date(2021, 12, 27), add=2)]
+
+        cdf = self.connect.createDataFrame(data, "date date, add integer")
+        sdf = self.spark.createDataFrame(data, "date date, add integer")
+
+        self.assertEqual(cdf.schema, sdf.schema)
+
     def test_create_empty_df(self):
         for schema in [
             "STRING",
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index c5add102c6a..3e46e1caa3e 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -53,16 +53,6 @@ class FunctionsParityTests(ReusedSQLTestCase, FunctionsTestsMixin):
     def test_basic_functions(self):
         super().test_basic_functions()
 
-    # TODO(SPARK-41899): DataFrame.createDataFrame converting int to bigint
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_date_add_function(self):
-        super().test_date_add_function()
-
-    # TODO(SPARK-41899): DataFrame.createDataFrame converting int to bigint
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_date_sub_function(self):
-        super().test_date_sub_function()
-
     # TODO(SPARK-41847): DataFrame mapfield,structlist invalid type
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_explode(self):
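
The two re-enabled parity tests exercise `date_add`/`date_sub` with a
column-valued day count, which needs `add` to be an integer column.
Roughly (a hedged sketch, not copied from the test file):

    import datetime
    from pyspark.sql import Row, functions as F

    df = spark.createDataFrame(
        [Row(date=datetime.date(2021, 12, 27), add=2)], "date date, add integer"
    )
    df.select(F.date_add(df.date, df.add)).show()  # 2021-12-29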


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
