This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 68c3354267d [SPARK-41810][CONNECT] Infer names from a list of dictionaries in SparkSession.createDataFrame
68c3354267d is described below

commit 68c3354267d30a96765a6592243205957d2cddf1
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Mon Jan 2 21:24:45 2023 +0900

    [SPARK-41810][CONNECT] Infer names from a list of dictionaries in SparkSession.createDataFrame
    
    ### What changes were proposed in this pull request?
    
    This PR proposes supporting field-name inference when the input data is a list of dictionaries in `SparkSession.createDataFrame`.
    
    For example,
    
    ```python
    spark.createDataFrame([{"course": "dotNET", "earnings": 10000, "year": 2012}]).show()
    ```
    
    **Before**:
    
    ```
    +------+-----+----+
    |    _1|   _2|  _3|
    +------+-----+----+
    |dotNET|10000|2012|
    +------+-----+----+
    ```
    
    **After**:
    
    ```
    +------+--------+----+
    |course|earnings|year|
    +------+--------+----+
    |dotNET|   10000|2012|
    +------+--------+----+
    ```
    
    ### Why are the changes needed?
    
    To match the behaviour of regular PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, not for end users, because Spark Connect has not been released yet.
    
    ### How was this patch tested?
    
    A unit test was added.
    
    Closes #39344 from HyukjinKwon/SPARK-41746.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/session.py              | 16 +++++++++------
 .../sql/tests/connect/test_connect_basic.py        | 24 ++++++++++++----------
 2 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index a461372c08c..0233bde1c17 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -218,15 +218,21 @@ class SparkSession:
 
         else:
             _data = list(data)
-            pdf = pd.DataFrame(_data)
 
-            if _schema is None and isinstance(_data[0], Row):
+            if _schema is None and (isinstance(_data[0], Row) or 
isinstance(_data[0], dict)):
+                if isinstance(_data[0], dict):
+                    # Sort the data to respect inferred schema.
+                    # For dictionaries, we sort the schema in alphabetical 
order.
+                    _data = [dict(sorted(d.items())) for d in _data]
+
                 _schema = self._inferSchemaFromList(_data, _cols)
                 if _cols is not None:
                     for i, name in enumerate(_cols):
                         _schema.fields[i].name = name
                         _schema.names[i] = name
 
+            pdf = pd.DataFrame(_data)
+
             if _cols is None:
                 _cols = ["_%s" % i for i in range(1, pdf.shape[1] + 1)]
 
@@ -342,11 +348,9 @@ def _test() -> None:
         # Spark Connect does not support to set master together.
         pyspark.sql.connect.session.SparkSession.__doc__ = None
         del pyspark.sql.connect.session.SparkSession.Builder.master.__doc__
-
-        # TODO(SPARK-41746): SparkSession.createDataFrame does not respect the 
column names in
-        #  dictionary
+        # RDD API is not supported in Spark Connect.
         del pyspark.sql.connect.session.SparkSession.createDataFrame.__doc__
-        del pyspark.sql.connect.session.SparkSession.read.__doc__
+
         # TODO(SPARK-41811): Implement SparkSession.sql's string formatter
         del pyspark.sql.connect.session.SparkSession.sql.__doc__
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 6a65e412dfd..7c17c5f6820 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -389,8 +389,8 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.connect.createDataFrame(data, "col1 int, col2 int, col3 
int").show()
 
     def test_with_local_rows(self):
-        # SPARK-41789: Test creating a dataframe with list of Rows
-        data = [
+        # SPARK-41789, SPARK-41810: Test creating a dataframe with list of 
rows and dictionaries
+        rows = [
             Row(course="dotNET", year=2012, earnings=10000),
             Row(course="Java", year=2012, earnings=20000),
             Row(course="dotNET", year=2012, earnings=5000),
@@ -398,19 +398,21 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             Row(course="Java", year=2013, earnings=30000),
             Row(course="Scala", year=2022, earnings=None),
         ]
+        dicts = [row.asDict() for row in rows]
 
-        sdf = self.spark.createDataFrame(data)
-        cdf = self.connect.createDataFrame(data)
+        for data in [rows, dicts]:
+            sdf = self.spark.createDataFrame(data)
+            cdf = self.connect.createDataFrame(data)
 
-        self.assertEqual(sdf.schema, cdf.schema)
-        self.assert_eq(sdf.toPandas(), cdf.toPandas())
+            self.assertEqual(sdf.schema, cdf.schema)
+            self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
-        # test with rename
-        sdf = self.spark.createDataFrame(data, schema=["a", "b", "c"])
-        cdf = self.connect.createDataFrame(data, schema=["a", "b", "c"])
+            # test with rename
+            sdf = self.spark.createDataFrame(data, schema=["a", "b", "c"])
+            cdf = self.connect.createDataFrame(data, schema=["a", "b", "c"])
 
-        self.assertEqual(sdf.schema, cdf.schema)
-        self.assert_eq(sdf.toPandas(), cdf.toPandas())
+            self.assertEqual(sdf.schema, cdf.schema)
+            self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
     def test_with_atom_type(self):
         for data in [[(1), (2), (3)], [1, 2, 3]]:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to