This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 514449b7cbf [SPARK-41899][CONNECT][PYTHON] `createDataFrame` should respect user-provided DDL schema
514449b7cbf is described below

commit 514449b7cbfca253773997fdd173dd138fbd2bf4
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sun Jan 8 13:02:45 2023 +0800

    [SPARK-41899][CONNECT][PYTHON] `createDataFrame` should respect user-provided DDL schema

    ### What changes were proposed in this pull request?
    Make `createDataFrame` respect the user-provided DDL schema instead of always using the schema inferred from the local data.

    ### Why are the changes needed?
    Consistency with `createDataFrame` in vanilla PySpark.

    ### Does this PR introduce _any_ user-facing change?
    Yes.

    ### How was this patch tested?
    Added a unit test and enabled previously skipped parity tests.

    Closes #39452 from zhengruifeng/connect_fix_41899.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/session.py              | 42 ++++++++++++----------
 .../sql/tests/connect/test_connect_basic.py        | 10 ++++++
 .../sql/tests/connect/test_parity_functions.py     | 10 ------
 3 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 33ec254fe43..bd6e8bd19f3 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -184,7 +184,6 @@ class SparkSession:
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
 
-        table: Optional[pa.Table] = None
         _schema: Optional[Union[AtomicType, StructType]] = None
         _schema_str: Optional[str] = None
         _cols: Optional[List[str]] = None
@@ -207,8 +206,11 @@ class SparkSession:
             else:
                 raise ValueError("can not infer schema from empty dataset")
 
+        _table: Optional[pa.Table] = None
+        _inferred_schema: Optional[StructType] = None
+
         if isinstance(data, pd.DataFrame):
-            table = pa.Table.from_pandas(data)
+            _table = pa.Table.from_pandas(data)
 
         elif isinstance(data, np.ndarray):
             if data.ndim not in [1, 2]:
@@ -227,7 +229,7 @@ class SparkSession:
                         f"new values have {len(_cols)} elements"
                     )
 
-                table = pa.Table.from_arrays([pa.array(data)], _cols)
+                _table = pa.Table.from_arrays([pa.array(data)], _cols)
             else:
                 if data.shape[1] != len(_cols):
                     raise ValueError(
@@ -235,7 +237,7 @@ class SparkSession:
                         f"new values have {len(_cols)} elements"
                     )
 
-                table = pa.Table.from_arrays(
+                _table = pa.Table.from_arrays(
                     [pa.array(data[::, i]) for i in range(0, data.shape[1])], _cols
                 )
 
@@ -248,35 +250,37 @@ class SparkSession:
                 # For dictionaries, we sort the schema in alphabetical order.
                 _data = [dict(sorted(d.items())) for d in _data]
 
-            _schema = self._inferSchemaFromList(_data, _cols)
+            _inferred_schema = self._inferSchemaFromList(_data, _cols)
 
             if _cols is not None:
                 for i, name in enumerate(_cols):
-                    _schema.fields[i].name = name
-                    _schema.names[i] = name
+                    _inferred_schema.fields[i].name = name
+                    _inferred_schema.names[i] = name
 
         if _cols is None:
-            if _schema is None:
+            if _schema is None and _inferred_schema is None:
                 if isinstance(_data[0], (list, tuple)):
                     _cols = ["_%s" % i for i in range(1, len(_data[0]) + 1)]
                 else:
                     _cols = ["_1"]
-            elif isinstance(_schema, StructType):
+            elif _schema is not None and isinstance(_schema, StructType):
                 _cols = _schema.names
+            elif _inferred_schema is not None:
+                _cols = _inferred_schema.names
             else:
                 _cols = ["value"]
 
         if isinstance(_data[0], Row):
-            table = pa.Table.from_pylist([row.asDict(recursive=True) for row in _data])
+            _table = pa.Table.from_pylist([row.asDict(recursive=True) for row in _data])
         elif isinstance(_data[0], dict):
-            table = pa.Table.from_pylist(_data)
+            _table = pa.Table.from_pylist(_data)
         elif isinstance(_data[0], (list, tuple)):
-            table = pa.Table.from_pylist([dict(zip(_cols, list(item))) for item in _data])
+            _table = pa.Table.from_pylist([dict(zip(_cols, list(item))) for item in _data])
         else:
             # input data can be [1, 2, 3]
-            table = pa.Table.from_pylist([dict(zip(_cols, [item])) for item in _data])
+            _table = pa.Table.from_pylist([dict(zip(_cols, [item])) for item in _data])
 
         # Validate number of columns
-        num_cols = table.shape[1]
+        num_cols = _table.shape[1]
         if (
             _schema is not None
             and isinstance(_schema, StructType)
@@ -294,13 +298,15 @@ class SparkSession:
             )
 
         if _schema is not None:
-            return DataFrame.withPlan(LocalRelation(table, schema=_schema.json()), self)
+            return DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
         elif _schema_str is not None:
-            return DataFrame.withPlan(LocalRelation(table, schema=_schema_str), self)
+            return DataFrame.withPlan(LocalRelation(_table, schema=_schema_str), self)
+        elif _inferred_schema is not None:
+            return DataFrame.withPlan(LocalRelation(_table, schema=_inferred_schema.json()), self)
         elif _cols is not None and len(_cols) > 0:
-            return DataFrame.withPlan(LocalRelation(table), self).toDF(*_cols)
+            return DataFrame.withPlan(LocalRelation(_table), self).toDF(*_cols)
         else:
-            return DataFrame.withPlan(LocalRelation(table), self)
+            return DataFrame.withPlan(LocalRelation(_table), self)
 
     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 235bd2815ae..3f82fdb4f4d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import datetime
 import unittest
 import shutil
 import tempfile
@@ -563,6 +565,14 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             sdf.select(SF.pmod("a", "b")).toPandas(),
         )
 
+    def test_cast_with_ddl(self):
+        data = [Row(date=datetime.date(2021, 12, 27), add=2)]
+
+        cdf = self.connect.createDataFrame(data, "date date, add integer")
+        sdf = self.spark.createDataFrame(data, "date date, add integer")
+
+        self.assertEqual(cdf.schema, sdf.schema)
+
     def test_create_empty_df(self):
         for schema in [
             "STRING",
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index c5add102c6a..3e46e1caa3e 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -53,16 +53,6 @@ class FunctionsParityTests(ReusedSQLTestCase, FunctionsTestsMixin):
     def test_basic_functions(self):
         super().test_basic_functions()
 
-    # TODO(SPARK-41899): DataFrame.createDataFrame converting int to bigint
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_date_add_function(self):
-        super().test_date_add_function()
-
-    # TODO(SPARK-41899): DataFrame.createDataFrame converting int to bigint
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_date_sub_function(self):
-        super().test_date_sub_function()
-
     # TODO(SPARK-41847): DataFrame mapfield,structlist invalid type
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_explode(self):
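For readers who want to try the fixed behavior end to end, here is a minimal sketch. Before this change, the schema inferred from the local data took precedence over a user-supplied DDL string for list-like input (the `_schema = self._inferSchemaFromList(...)` assignment clobbered the user schema, which is why Python ints surfaced as `bigint`); after the change, the user schema is threaded through `LocalRelation` and wins. The connect URL below is a placeholder, and the snippet assumes a Spark release where the session builder exposes `remote()` (Spark 3.4+).

    import datetime

    from pyspark.sql import Row, SparkSession

    # Placeholder URL; point this at a running Spark Connect server.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    data = [Row(date=datetime.date(2021, 12, 27), add=2)]

    # With this fix the DDL string is honored for locally created data,
    # so `add` is typed as integer rather than the Arrow-inferred bigint.
    df = spark.createDataFrame(data, "date date, add integer")
    df.printSchema()
    # Expected:
    # root
    #  |-- date: date (nullable = true)
    #  |-- add: integer (nullable = true)

This is also why the two date-function parity tests re-enabled above now pass: `date_add`/`date_sub` expect an integer column, which the inferred `bigint` schema previously broke.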