This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 39e87d66b07 [SPARK-42900][CONNECT][PYTHON] Fix createDataFrame to respect inference and column names
39e87d66b07 is described below

commit 39e87d66b07beff91aebed6163ee82a35fbd1fcf
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu Mar 23 16:45:30 2023 +0800

[SPARK-42900][CONNECT][PYTHON] Fix createDataFrame to respect inference and column names

### What changes were proposed in this pull request?

Fixes `createDataFrame` to respect the inferred schema when a list of column names is given.

### Why are the changes needed?

Currently, when a column name list is provided as the schema, the type inference result is ignored. As a result, `createDataFrame` called on UDT objects together with a column name list does not preserve the UDT type. For example:

```py
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(1.0, 1.0, Vectors.dense(0.0, 5.0)), (0.0, 2.0, Vectors.dense(1.0, 2.0))], ["label", "weight", "features"])
>>> df.printSchema()
root
 |-- label: double (nullable = true)
 |-- weight: double (nullable = true)
 |-- features: struct (nullable = true)
 |    |-- type: byte (nullable = false)
 |    |-- size: integer (nullable = true)
 |    |-- indices: array (nullable = true)
 |    |    |-- element: integer (containsNull = false)
 |    |-- values: array (nullable = true)
 |    |    |-- element: double (containsNull = false)
```

whereas it should be:

```py
>>> df.printSchema()
root
 |-- label: double (nullable = true)
 |-- weight: double (nullable = true)
 |-- features: vector (nullable = true)
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added the related tests.

Closes #40527 from ueshin/issues/SPARK-42900/cols.

Authored-by: Takuya UESHIN <ues...@databricks.com>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
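As a quick, hedged illustration (not part of the patch itself): the snippet below re-runs the example from the description and checks that the inferred UDT now survives when only column names are supplied. It assumes an active SparkSession bound to `spark`, as in the example above, and uses `VectorUDT` from `pyspark.ml.linalg` for the check.

```py
# Illustration only; assumes `spark` is an active SparkSession (e.g. a Spark Connect session).
from pyspark.ml.linalg import Vectors, VectorUDT

df = spark.createDataFrame(
    [(1.0, 1.0, Vectors.dense(0.0, 5.0)), (0.0, 2.0, Vectors.dense(1.0, 2.0))],
    ["label", "weight", "features"],  # column names only; types are inferred
)

# With this change, the inferred UDT is kept instead of being flattened to a struct.
assert isinstance(df.schema["features"].dataType, VectorUDT)
df.printSchema()
# root
#  |-- label: double (nullable = true)
#  |-- weight: double (nullable = true)
#  |-- features: vector (nullable = true)
```

In effect, passing a column-name list now behaves like inferring the schema first and renaming afterwards with `toDF(*cols)`, which is what the `session.py` change below does.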
---
 .../sql/connect/planner/SparkConnectPlanner.scala |  24 ++++++-----
 python/pyspark/sql/connect/session.py             |  17 +++++---
 .../sql/tests/connect/test_connect_basic.py       |  18 +++++----
 python/pyspark/sql/tests/test_types.py            |  46 ++++++++++++++++++++++
 4 files changed, 81 insertions(+), 24 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4f142beaf67..f6fee4250b8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -718,26 +718,30 @@ class SparkConnectPlanner(val session: SparkSession) {
     if (schema == null) {
       logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
     } else {
-      def udtToSqlType(dt: DataType): DataType = dt match {
-        case udt: UserDefinedType[_] => udt.sqlType
+      def normalize(dt: DataType): DataType = dt match {
+        case udt: UserDefinedType[_] => normalize(udt.sqlType)
         case StructType(fields) =>
-          val newFields = fields.map { case StructField(name, dataType, nullable, metadata) =>
-            StructField(name, udtToSqlType(dataType), nullable, metadata)
+          val newFields = fields.zipWithIndex.map {
+            case (StructField(_, dataType, nullable, metadata), i) =>
+              StructField(s"col_$i", normalize(dataType), nullable, metadata)
           }
           StructType(newFields)
         case ArrayType(elementType, containsNull) =>
-          ArrayType(udtToSqlType(elementType), containsNull)
+          ArrayType(normalize(elementType), containsNull)
         case MapType(keyType, valueType, valueContainsNull) =>
-          MapType(udtToSqlType(keyType), udtToSqlType(valueType), valueContainsNull)
+          MapType(normalize(keyType), normalize(valueType), valueContainsNull)
         case _ => dt
       }
-      val sqlTypeOnlySchema = udtToSqlType(schema).asInstanceOf[StructType]
+      val normalized = normalize(schema).asInstanceOf[StructType]
       val project = Dataset
-        .ofRows(session, logicalPlan = logical.LocalRelation(attributes))
-        .toDF(sqlTypeOnlySchema.names: _*)
-        .to(sqlTypeOnlySchema)
+        .ofRows(
+          session,
+          logicalPlan =
+            logical.LocalRelation(normalize(structType).asInstanceOf[StructType].toAttributes))
+        .toDF(normalized.names: _*)
+        .to(normalized)
         .logicalPlan
         .asInstanceOf[Project]
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 8fe5020f4a4..4bd5b26765d 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -312,6 +312,9 @@ class SparkSession:
                 [pa.array(data[::, i]) for i in range(0, data.shape[1])], _cols
             )
 
+            # The _table should already have the proper column names.
+            _cols = None
+
         else:
             _data = list(data)
 
@@ -357,7 +360,7 @@
                     "a StructType Schema is required in this case"
                 )
 
-        if _schema_str is None and _cols is None:
+        if _schema_str is None:
             _schema = _inferred_schema
 
         from pyspark.sql.connect.conversion import LocalDataToArrowConversion
@@ -376,13 +379,15 @@ class SparkSession:
         )
 
         if _schema is not None:
-            return DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
+            df = DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
         elif _schema_str is not None:
-            return DataFrame.withPlan(LocalRelation(_table, schema=_schema_str), self)
-        elif _cols is not None and len(_cols) > 0:
-            return DataFrame.withPlan(LocalRelation(_table), self).toDF(*_cols)
+            df = DataFrame.withPlan(LocalRelation(_table, schema=_schema_str), self)
         else:
-            return DataFrame.withPlan(LocalRelation(_table), self)
+            df = DataFrame.withPlan(LocalRelation(_table), self)
+
+        if _cols is not None and len(_cols) > 0:
+            df = df.toDF(*_cols)
+        return df
 
     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 682b3471a74..79c8dba537c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -522,11 +522,12 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             ["a", "b", "c", "d"],
             ("x1", "x2", "x3", "x4"),
         ]:
-            sdf = self.spark.createDataFrame(data, schema=schema)
-            cdf = self.connect.createDataFrame(data, schema=schema)
+            with self.subTest(schema=schema):
+                sdf = self.spark.createDataFrame(data, schema=schema)
+                cdf = self.connect.createDataFrame(data, schema=schema)
 
-            self.assertEqual(sdf.schema, cdf.schema)
-            self.assert_eq(sdf.toPandas(), cdf.toPandas())
+                self.assertEqual(sdf.schema, cdf.schema)
+                self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
         with self.assertRaisesRegex(
             ValueError,
@@ -897,11 +898,12 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         # |    |    |-- value: long (valueContainsNull = true)
 
         for data in [data1, data2, data3, data4, data5]:
-            cdf = self.connect.createDataFrame(data)
-            sdf = self.spark.createDataFrame(data)
+            with self.subTest(data=data):
+                cdf = self.connect.createDataFrame(data)
+                sdf = self.spark.createDataFrame(data)
 
-            self.assertEqual(cdf.schema, sdf.schema)
-            self.assertEqual(cdf.collect(), sdf.collect())
+                self.assertEqual(cdf.schema, sdf.schema)
+                self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_create_df_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index aaac43cdf67..bee899e928e 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -595,6 +595,29 @@ class TypesTestsMixin:
             point = self.spark.sql("SELECT point FROM labeled_point").head().point
             self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
+    def test_infer_schema_with_udt_with_column_names(self):
+        row = (1.0, ExamplePoint(1.0, 2.0))
+        df = self.spark.createDataFrame([row], ["label", "point"])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), ExamplePointUDT)
+
+        with self.tempView("labeled_point"):
+            df.createOrReplaceTempView("labeled_point")
+            point = self.spark.sql("SELECT point FROM labeled_point").head().point
+            self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = (1.0, PythonOnlyPoint(1.0, 2.0))
+        df = self.spark.createDataFrame([row], ["label", "point"])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), PythonOnlyUDT)
+
+        with self.tempView("labeled_point"):
+            df.createOrReplaceTempView("labeled_point")
+            point = self.spark.sql("SELECT point FROM labeled_point").head().point
+            self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
     def test_apply_schema_with_udt(self):
         row = (1.0, ExamplePoint(1.0, 2.0))
         schema = StructType(
@@ -618,6 +641,29 @@ class TypesTestsMixin:
             point = df.head().point
             self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
+    def test_apply_schema_with_nullable_udt(self):
+        rows = [(1.0, ExamplePoint(1.0, 2.0)), (2.0, None)]
+        schema = StructType(
+            [
+                StructField("label", DoubleType(), False),
+                StructField("point", ExamplePointUDT(), True),
+            ]
+        )
+        df = self.spark.createDataFrame(rows, schema)
+        points = [row.point for row in df.collect()]
+        self.assertEqual(points, [ExamplePoint(1.0, 2.0), None])
+
+        rows = [(1.0, PythonOnlyPoint(1.0, 2.0)), (2.0, None)]
+        schema = StructType(
+            [
+                StructField("label", DoubleType(), False),
+                StructField("point", PythonOnlyUDT(), True),
+            ]
+        )
+        df = self.spark.createDataFrame(rows, schema)
+        points = [row.point for row in df.collect()]
+        self.assertEqual(points, [PythonOnlyPoint(1.0, 2.0), None])
+
     def test_udf_with_udt(self):
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         df = self.spark.createDataFrame([row])

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org