This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 68c3354267d [SPARK-41810][CONNECT] Infer names from a list of dictionaries in SparkSession.createDataFrame 68c3354267d is described below commit 68c3354267d30a96765a6592243205957d2cddf1 Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Mon Jan 2 21:24:45 2023 +0900 [SPARK-41810][CONNECT] Infer names from a list of dictionaries in SparkSession.createDataFrame ### What changes were proposed in this pull request? This PR proposes to support inferring field names when the input data is a list of dictionaries in `SparkSession.createDataFrame`. For example, ```python spark.createDataFrame([{"course": "dotNET", "earnings": 10000, "year": 2012}]).show() ``` **Before**: ``` +------+-----+----+ | _1| _2| _3| +------+-----+----+ |dotNET|10000|2012| +------+-----+----+ ``` **After**: ``` +------+--------+----+ |course|earnings|year| +------+--------+----+ |dotNET| 10000|2012| +------+--------+----+ ``` ### Why are the changes needed? To match the behaviour of regular PySpark. ### Does this PR introduce _any_ user-facing change? No to end users because Spark Connect has not been released. ### How was this patch tested? A unit test was added. Closes #39344 from HyukjinKwon/SPARK-41746. 
Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/connect/session.py | 16 +++++++++------ .../sql/tests/connect/test_connect_basic.py | 24 ++++++++++++---------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index a461372c08c..0233bde1c17 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -218,15 +218,21 @@ class SparkSession: else: _data = list(data) - pdf = pd.DataFrame(_data) - if _schema is None and isinstance(_data[0], Row): + if _schema is None and (isinstance(_data[0], Row) or isinstance(_data[0], dict)): + if isinstance(_data[0], dict): + # Sort the data to respect inferred schema. + # For dictionaries, we sort the schema in alphabetical order. + _data = [dict(sorted(d.items())) for d in _data] + _schema = self._inferSchemaFromList(_data, _cols) if _cols is not None: for i, name in enumerate(_cols): _schema.fields[i].name = name _schema.names[i] = name + pdf = pd.DataFrame(_data) + if _cols is None: _cols = ["_%s" % i for i in range(1, pdf.shape[1] + 1)] @@ -342,11 +348,9 @@ def _test() -> None: # Spark Connect does not support to set master together. pyspark.sql.connect.session.SparkSession.__doc__ = None del pyspark.sql.connect.session.SparkSession.Builder.master.__doc__ - - # TODO(SPARK-41746): SparkSession.createDataFrame does not respect the column names in - # dictionary + # RDD API is not supported in Spark Connect. 
del pyspark.sql.connect.session.SparkSession.createDataFrame.__doc__ - del pyspark.sql.connect.session.SparkSession.read.__doc__ + # TODO(SPARK-41811): Implement SparkSession.sql's string formatter del pyspark.sql.connect.session.SparkSession.sql.__doc__ diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 6a65e412dfd..7c17c5f6820 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -389,8 +389,8 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show() def test_with_local_rows(self): - # SPARK-41789: Test creating a dataframe with list of Rows - data = [ + # SPARK-41789, SPARK-41810: Test creating a dataframe with list of rows and dictionaries + rows = [ Row(course="dotNET", year=2012, earnings=10000), Row(course="Java", year=2012, earnings=20000), Row(course="dotNET", year=2012, earnings=5000), @@ -398,19 +398,21 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): Row(course="Java", year=2013, earnings=30000), Row(course="Scala", year=2022, earnings=None), ] + dicts = [row.asDict() for row in rows] - sdf = self.spark.createDataFrame(data) - cdf = self.connect.createDataFrame(data) + for data in [rows, dicts]: + sdf = self.spark.createDataFrame(data) + cdf = self.connect.createDataFrame(data) - self.assertEqual(sdf.schema, cdf.schema) - self.assert_eq(sdf.toPandas(), cdf.toPandas()) + self.assertEqual(sdf.schema, cdf.schema) + self.assert_eq(sdf.toPandas(), cdf.toPandas()) - # test with rename - sdf = self.spark.createDataFrame(data, schema=["a", "b", "c"]) - cdf = self.connect.createDataFrame(data, schema=["a", "b", "c"]) + # test with rename + sdf = self.spark.createDataFrame(data, schema=["a", "b", "c"]) + cdf = self.connect.createDataFrame(data, schema=["a", "b", "c"]) - self.assertEqual(sdf.schema, cdf.schema) - 
self.assert_eq(sdf.toPandas(), cdf.toPandas()) + self.assertEqual(sdf.schema, cdf.schema) + self.assert_eq(sdf.toPandas(), cdf.toPandas()) def test_with_atom_type(self): for data in [[(1), (2), (3)], [1, 2, 3]]: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org