This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 4d0629942b8 [SPARK-42900][CONNECT][PYTHON] Fix createDataFrame to respect inference and column names
4d0629942b8 is described below

commit 4d0629942b8a6b5295fd1c2f5693a31b6bfcdddd
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu Mar 23 16:45:30 2023 +0800

    [SPARK-42900][CONNECT][PYTHON] Fix createDataFrame to respect inference and column names
    
    ### What changes were proposed in this pull request?
    
    Fixes `createDataFrame` to respect inference and column names.
    
    ### Why are the changes needed?
    
    Currently, when a column name list is provided as the schema, the type inference result is ignored.
    
    As a result, `createDataFrame` from UDT objects with a column name list doesn't keep the UDT type.
    
    For example:
    
    ```py
    >>> from pyspark.ml.linalg import Vectors
    >>> df = spark.createDataFrame([(1.0, 1.0, Vectors.dense(0.0, 5.0)), (0.0, 2.0, Vectors.dense(1.0, 2.0))], ["label", "weight", "features"])
    >>> df.printSchema()
    root
     |-- label: double (nullable = true)
     |-- weight: double (nullable = true)
     |-- features: struct (nullable = true)
     |    |-- type: byte (nullable = false)
     |    |-- size: integer (nullable = true)
     |    |-- indices: array (nullable = true)
     |    |    |-- element: integer (containsNull = false)
     |    |-- values: array (nullable = true)
     |    |    |-- element: double (containsNull = false)
    ```
    
    whereas it should be:
    
    ```py
    >>> df.printSchema()
    root
     |-- label: double (nullable = true)
     |-- weight: double (nullable = true)
     |-- features: vector (nullable = true)
    ```
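
    To confirm the fix, a minimal check that the UDT is preserved (a sketch; it assumes the same `df` as above and that `pyspark.ml` is importable):

    ```py
    >>> from pyspark.ml.linalg import VectorUDT
    >>> isinstance(df.schema["features"].dataType, VectorUDT)
    True
    ```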
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added the related tests.
    
    Closes #40527 from ueshin/issues/SPARK-42900/cols.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
    (cherry picked from commit 39e87d66b07beff91aebed6163ee82a35fbd1fcf)
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 24 ++++++-----
 python/pyspark/sql/connect/session.py              | 17 +++++---
 .../sql/tests/connect/test_connect_basic.py        | 18 +++++----
 python/pyspark/sql/tests/test_types.py             | 46 ++++++++++++++++++++++
 4 files changed, 81 insertions(+), 24 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 9ebaed44820..ca219214c66 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -695,26 +695,30 @@ class SparkConnectPlanner(val session: SparkSession) {
       if (schema == null) {
         logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
       } else {
-        def udtToSqlType(dt: DataType): DataType = dt match {
-          case udt: UserDefinedType[_] => udt.sqlType
+        def normalize(dt: DataType): DataType = dt match {
+          case udt: UserDefinedType[_] => normalize(udt.sqlType)
           case StructType(fields) =>
-            val newFields = fields.map { case StructField(name, dataType, nullable, metadata) =>
-              StructField(name, udtToSqlType(dataType), nullable, metadata)
+            val newFields = fields.zipWithIndex.map {
+              case (StructField(_, dataType, nullable, metadata), i) =>
+                StructField(s"col_$i", normalize(dataType), nullable, metadata)
             }
             StructType(newFields)
           case ArrayType(elementType, containsNull) =>
-            ArrayType(udtToSqlType(elementType), containsNull)
+            ArrayType(normalize(elementType), containsNull)
           case MapType(keyType, valueType, valueContainsNull) =>
-            MapType(udtToSqlType(keyType), udtToSqlType(valueType), valueContainsNull)
+            MapType(normalize(keyType), normalize(valueType), valueContainsNull)
           case _ => dt
         }
 
-        val sqlTypeOnlySchema = udtToSqlType(schema).asInstanceOf[StructType]
+        val normalized = normalize(schema).asInstanceOf[StructType]
 
         val project = Dataset
-          .ofRows(session, logicalPlan = logical.LocalRelation(attributes))
-          .toDF(sqlTypeOnlySchema.names: _*)
-          .to(sqlTypeOnlySchema)
+          .ofRows(
+            session,
+            logicalPlan =
+              logical.LocalRelation(normalize(structType).asInstanceOf[StructType].toAttributes))
+          .toDF(normalized.names: _*)
+          .to(normalized)
           .logicalPlan
           .asInstanceOf[Project]
 
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index b20d1cd2325..b49f6df969c 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -303,6 +303,9 @@ class SparkSession:
                    [pa.array(data[::, i]) for i in range(0, data.shape[1])], _cols
                 )
 
+            # The _table should already have the proper column names.
+            _cols = None
+
         else:
             _data = list(data)
 
@@ -348,7 +351,7 @@ class SparkSession:
                             "a StructType Schema is required in this case"
                         )
 
-                if _schema_str is None and _cols is None:
+                if _schema_str is None:
                     _schema = _inferred_schema
 
            from pyspark.sql.connect.conversion import LocalDataToArrowConversion
@@ -367,13 +370,15 @@ class SparkSession:
             )
 
         if _schema is not None:
-            return DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
+            df = DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
         elif _schema_str is not None:
-            return DataFrame.withPlan(LocalRelation(_table, schema=_schema_str), self)
-        elif _cols is not None and len(_cols) > 0:
-            return DataFrame.withPlan(LocalRelation(_table), self).toDF(*_cols)
+            df = DataFrame.withPlan(LocalRelation(_table, schema=_schema_str), self)
         else:
-            return DataFrame.withPlan(LocalRelation(_table), self)
+            df = DataFrame.withPlan(LocalRelation(_table), self)
+
+        if _cols is not None and len(_cols) > 0:
+            df = df.toDF(*_cols)
+        return df
 
     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f911ca9ba78..1cac84659b3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -522,11 +522,12 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             ["a", "b", "c", "d"],
             ("x1", "x2", "x3", "x4"),
         ]:
-            sdf = self.spark.createDataFrame(data, schema=schema)
-            cdf = self.connect.createDataFrame(data, schema=schema)
+            with self.subTest(schema=schema):
+                sdf = self.spark.createDataFrame(data, schema=schema)
+                cdf = self.connect.createDataFrame(data, schema=schema)
 
-            self.assertEqual(sdf.schema, cdf.schema)
-            self.assert_eq(sdf.toPandas(), cdf.toPandas())
+                self.assertEqual(sdf.schema, cdf.schema)
+                self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
         with self.assertRaisesRegex(
             ValueError,
@@ -897,11 +898,12 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         # |    |    |    |-- value: long (valueContainsNull = true)
 
         for data in [data1, data2, data3, data4, data5]:
-            cdf = self.connect.createDataFrame(data)
-            sdf = self.spark.createDataFrame(data)
+            with self.subTest(data=data):
+                cdf = self.connect.createDataFrame(data)
+                sdf = self.spark.createDataFrame(data)
 
-            self.assertEqual(cdf.schema, sdf.schema)
-            self.assertEqual(cdf.collect(), sdf.collect())
+                self.assertEqual(cdf.schema, sdf.schema)
+                self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_create_df_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index aaac43cdf67..bee899e928e 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -595,6 +595,29 @@ class TypesTestsMixin:
            point = self.spark.sql("SELECT point FROM labeled_point").head().point
             self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
+    def test_infer_schema_with_udt_with_column_names(self):
+        row = (1.0, ExamplePoint(1.0, 2.0))
+        df = self.spark.createDataFrame([row], ["label", "point"])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), ExamplePointUDT)
+
+        with self.tempView("labeled_point"):
+            df.createOrReplaceTempView("labeled_point")
+            point = self.spark.sql("SELECT point FROM labeled_point").head().point
+            self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = (1.0, PythonOnlyPoint(1.0, 2.0))
+        df = self.spark.createDataFrame([row], ["label", "point"])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), PythonOnlyUDT)
+
+        with self.tempView("labeled_point"):
+            df.createOrReplaceTempView("labeled_point")
+            point = self.spark.sql("SELECT point FROM labeled_point").head().point
+            self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
     def test_apply_schema_with_udt(self):
         row = (1.0, ExamplePoint(1.0, 2.0))
         schema = StructType(
@@ -618,6 +641,29 @@ class TypesTestsMixin:
         point = df.head().point
         self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
+    def test_apply_schema_with_nullable_udt(self):
+        rows = [(1.0, ExamplePoint(1.0, 2.0)), (2.0, None)]
+        schema = StructType(
+            [
+                StructField("label", DoubleType(), False),
+                StructField("point", ExamplePointUDT(), True),
+            ]
+        )
+        df = self.spark.createDataFrame(rows, schema)
+        points = [row.point for row in df.collect()]
+        self.assertEqual(points, [ExamplePoint(1.0, 2.0), None])
+
+        rows = [(1.0, PythonOnlyPoint(1.0, 2.0)), (2.0, None)]
+        schema = StructType(
+            [
+                StructField("label", DoubleType(), False),
+                StructField("point", PythonOnlyUDT(), True),
+            ]
+        )
+        df = self.spark.createDataFrame(rows, schema)
+        points = [row.point for row in df.collect()]
+        self.assertEqual(points, [PythonOnlyPoint(1.0, 2.0), None])
+
     def test_udf_with_udt(self):
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         df = self.spark.createDataFrame([row])

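A note for reviewers on the `session.py` change: a plain column-name list no longer suppresses schema inference. The DataFrame is now built from the inferred schema (UDTs included) and the names are applied afterwards via `toDF`. A rough sketch of the resulting behavior, not the internal code path (assumes a Spark Connect session bound to `spark`):

```py
from pyspark.ml.linalg import Vectors

data = [(1.0, 1.0, Vectors.dense(0.0, 5.0)), (0.0, 2.0, Vectors.dense(1.0, 2.0))]

# Inference alone keeps the UDT (vector) type for the third column...
inferred = spark.createDataFrame(data)
# ...and the name list is now just a rename applied on top of it.
named = inferred.toDF("label", "weight", "features")

# Equivalent to the one-step call fixed by this patch:
direct = spark.createDataFrame(data, ["label", "weight", "features"])
assert named.schema == direct.schema
```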
