This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cfbed998530e Revert "[SPARK-48322][SPARK-42965][SQL][CONNECT][PYTHON] Drop internal metadata in `DataFrame.schema`"
cfbed998530e is described below

commit cfbed998530efaaf17f36d99a9462376eaa7d2ad
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 29 20:44:36 2024 +0800

    Revert "[SPARK-48322][SPARK-42965][SQL][CONNECT][PYTHON] Drop internal 
metadata in `DataFrame.schema`"
    
    revert https://github.com/apache/spark/pull/46636
    
    https://github.com/apache/spark/pull/46636#issuecomment-2137321359
    
    Closes #46790 from zhengruifeng/revert_metadata_drop.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/pandas/internal.py                  | 37 +++++++++++++++++-----
 .../sql/tests/connect/test_connect_function.py     |  4 ++-
 python/pyspark/sql/types.py                        | 13 ++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  4 +--
 .../apache/spark/sql/DataFrameAggregateSuite.scala |  5 +--
 5 files changed, 50 insertions(+), 13 deletions(-)

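For context, a minimal PySpark sketch of the behavior this revert restores (assuming a running Spark session; the expected metadata value is taken from the test changes below):

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()

    # sum("id") is auto-aliased to "sum(id)"; with this revert, the internal
    # alias metadata is no longer stripped from Dataset.schema.
    df = spark.range(3).select(F.sum("id"))
    print(df.schema["sum(id)"].metadata)  # expected: {'__autoGeneratedAlias': 'true'}
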
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index fd0f28e50b2f..04285aa2d879 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -33,6 +33,7 @@ from pyspark.sql import (
     Window,
 )
 from pyspark.sql.types import (  # noqa: F401
+    _drop_metadata,
     BooleanType,
     DataType,
     LongType,
@@ -756,10 +757,20 @@ class InternalFrame:
 
         if is_testing():
             struct_fields = spark_frame.select(index_spark_columns).schema.fields
-            assert all(
-                index_field.struct_field == struct_field
-                for index_field, struct_field in zip(index_fields, struct_fields)
-            ), (index_fields, struct_fields)
+            if is_remote():
+                # TODO(SPARK-42965): For some reason, the metadata of StructField is different
+                # in a few tests when using Spark Connect. However, the function works properly.
+                # Therefore, we temporarily perform Spark Connect tests by excluding metadata
+                # until the issue is resolved.
+                assert all(
+                    _drop_metadata(index_field.struct_field) == _drop_metadata(struct_field)
+                    for index_field, struct_field in zip(index_fields, struct_fields)
+                ), (index_fields, struct_fields)
+            else:
+                assert all(
+                    index_field.struct_field == struct_field
+                    for index_field, struct_field in zip(index_fields, struct_fields)
+                ), (index_fields, struct_fields)
 
         self._index_fields: List[InternalField] = index_fields
 
@@ -774,10 +785,20 @@ class InternalFrame:
 
         if is_testing():
             struct_fields = spark_frame.select(data_spark_columns).schema.fields
-            assert all(
-                data_field.struct_field == struct_field
-                for data_field, struct_field in zip(data_fields, struct_fields)
-            ), (data_fields, struct_fields)
+            if is_remote():
+                # TODO(SPARK-42965): For some reason, the metadata of StructField is different
+                # in a few tests when using Spark Connect. However, the function works properly.
+                # Therefore, we temporarily perform Spark Connect tests by excluding metadata
+                # until the issue is resolved.
+                assert all(
+                    _drop_metadata(data_field.struct_field) == _drop_metadata(struct_field)
+                    for data_field, struct_field in zip(data_fields, struct_fields)
+                ), (data_fields, struct_fields)
+            else:
+                assert all(
+                    data_field.struct_field == struct_field
+                    for data_field, struct_field in zip(data_fields, struct_fields)
+                ), (data_fields, struct_fields)
 
         self._data_fields: List[InternalField] = data_fields
 
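The relaxed comparison used under Spark Connect above can be seen in isolation; a minimal sketch using the private `_drop_metadata` helper that this revert re-adds to pyspark.sql.types (field name and metadata key are illustrative):

    from pyspark.sql.types import LongType, StructField, _drop_metadata

    # Two fields that differ only in metadata compare unequal...
    a = StructField("id", LongType(), True, {"__autoGeneratedAlias": "true"})
    b = StructField("id", LongType(), True)
    assert a != b

    # ...but compare equal once metadata is dropped on both sides.
    assert _drop_metadata(a) == _drop_metadata(b)
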
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 1fb0195b5203..0f0abfd4b856 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -22,6 +22,7 @@ from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql import SparkSession as PySparkSession
 from pyspark.sql.types import (
+    _drop_metadata,
     StringType,
     StructType,
     StructField,
@@ -1673,7 +1674,8 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
             )
         )
 
-        self.assertEqual(cdf.schema, sdf.schema)
+        # TODO: 'cdf.schema' has an extra metadata '{'__autoGeneratedAlias': 'true'}'
+        self.assertEqual(_drop_metadata(cdf.schema), _drop_metadata(sdf.schema))
         self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_csv_functions(self):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 62f09e948792..c72ff72ce426 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1770,6 +1770,19 @@ _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
 _COLLATIONS_METADATA_KEY = "__COLLATIONS"
 
 
+def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]:
+    assert isinstance(d, (DataType, StructField))
+    if isinstance(d, StructField):
+        return StructField(d.name, _drop_metadata(d.dataType), d.nullable, None)
+    elif isinstance(d, StructType):
+        return StructType([cast(StructField, _drop_metadata(f)) for f in d.fields])
+    elif isinstance(d, ArrayType):
+        return ArrayType(_drop_metadata(d.elementType), d.containsNull)
+    elif isinstance(d, MapType):
+        return MapType(_drop_metadata(d.keyType), _drop_metadata(d.valueType), d.valueContainsNull)
+    return d
+
+
 def _parse_datatype_string(s: str) -> DataType:
     """
     Parses the given data type string to a :class:`DataType`. The data type string format equals
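
Note that `_drop_metadata` recurses through struct, array, and map types; a small sketch (the field name and metadata key are illustrative only):

    from pyspark.sql.types import ArrayType, StringType, StructField, StructType, _drop_metadata

    nested = StructType(
        [StructField("tags", ArrayType(StringType()), True, {"key": "value"})]
    )

    # The metadata on the nested field is removed; StructField normalizes
    # a metadata of None to an empty dict.
    assert _drop_metadata(nested)["tags"].metadata == {}
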
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index afde54fc3d11..c7511737b2b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
-import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
 import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution._
@@ -561,7 +561,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def schema: StructType = sparkSession.withActive {
-    removeInternalMetadata(queryExecution.analyzed.schema)
+    queryExecution.analyzed.schema
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index a89cae865435..620ee430cab2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -24,6 +24,7 @@ import scala.util.Random
 import org.scalatest.matchers.must.Matchers.the
 
 import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
+import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1464,7 +1465,7 @@ class DataFrameAggregateSuite extends QueryTest
         Duration.ofSeconds(14)) ::
       Nil)
     assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
-    val metadata = Metadata.empty
+    val metadata = new MetadataBuilder().putString(AUTO_GENERATED_ALIAS, "true").build()
     assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
       StructField("sum(year-month)", YearMonthIntervalType(), metadata = metadata),
       StructField("sum(year)", YearMonthIntervalType(YEAR), metadata = metadata),
@@ -1598,7 +1599,7 @@ class DataFrameAggregateSuite extends QueryTest
         Duration.ofMinutes(4).plusSeconds(20),
         Duration.ofSeconds(7)) :: Nil)
     assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
-    val metadata = Metadata.empty
+    val metadata = new MetadataBuilder().putString(AUTO_GENERATED_ALIAS, "true").build()
     assert(avgDF2.schema == StructType(Seq(
       StructField("class", IntegerType, false),
       StructField("avg(year-month)", YearMonthIntervalType(), metadata = metadata),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
