This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new cfbed998530e Revert "[SPARK-48322][SPARK-42965][SQL][CONNECT][PYTHON] Drop internal metadata in `DataFrame.schema`"
cfbed998530e is described below

commit cfbed998530efaaf17f36d99a9462376eaa7d2ad
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 29 20:44:36 2024 +0800

    Revert "[SPARK-48322][SPARK-42965][SQL][CONNECT][PYTHON] Drop internal metadata in `DataFrame.schema`"

    This reverts https://github.com/apache/spark/pull/46636; see
    https://github.com/apache/spark/pull/46636#issuecomment-2137321359 for context.

    Closes #46790 from zhengruifeng/revert_metadata_drop.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/pandas/internal.py                  | 37 +++++++++++++++++-----
 .../sql/tests/connect/test_connect_function.py     |  4 ++-
 python/pyspark/sql/types.py                        | 13 ++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  4 +--
 .../apache/spark/sql/DataFrameAggregateSuite.scala |  5 +--
 5 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index fd0f28e50b2f..04285aa2d879 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -33,6 +33,7 @@ from pyspark.sql import (
     Window,
 )
 from pyspark.sql.types import (  # noqa: F401
+    _drop_metadata,
     BooleanType,
     DataType,
     LongType,
@@ -756,10 +757,20 @@ class InternalFrame:

         if is_testing():
             struct_fields = spark_frame.select(index_spark_columns).schema.fields
-            assert all(
-                index_field.struct_field == struct_field
-                for index_field, struct_field in zip(index_fields, struct_fields)
-            ), (index_fields, struct_fields)
+            if is_remote():
+                # TODO(SPARK-42965): For some reason, the metadata of StructField is different
+                # in a few tests when using Spark Connect. However, the function works properly.
+                # Therefore, we temporarily run the Spark Connect tests with metadata excluded
+                # until the issue is resolved.
+                assert all(
+                    _drop_metadata(index_field.struct_field) == _drop_metadata(struct_field)
+                    for index_field, struct_field in zip(index_fields, struct_fields)
+                ), (index_fields, struct_fields)
+            else:
+                assert all(
+                    index_field.struct_field == struct_field
+                    for index_field, struct_field in zip(index_fields, struct_fields)
+                ), (index_fields, struct_fields)

         self._index_fields: List[InternalField] = index_fields

@@ -774,10 +785,20 @@ class InternalFrame:

         if is_testing():
             struct_fields = spark_frame.select(data_spark_columns).schema.fields
-            assert all(
-                data_field.struct_field == struct_field
-                for data_field, struct_field in zip(data_fields, struct_fields)
-            ), (data_fields, struct_fields)
+            if is_remote():
+                # TODO(SPARK-42965): For some reason, the metadata of StructField is different
+                # in a few tests when using Spark Connect. However, the function works properly.
+                # Therefore, we temporarily run the Spark Connect tests with metadata excluded
+                # until the issue is resolved.
+                assert all(
+                    _drop_metadata(data_field.struct_field) == _drop_metadata(struct_field)
+                    for data_field, struct_field in zip(data_fields, struct_fields)
+                ), (data_fields, struct_fields)
+            else:
+                assert all(
+                    data_field.struct_field == struct_field
+                    for data_field, struct_field in zip(data_fields, struct_fields)
+                ), (data_fields, struct_fields)

         self._data_fields: List[InternalField] = data_fields
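For context: the internal metadata at issue is bookkeeping such as `__autoGeneratedAlias`,
which the analyzer attaches to auto-aliased columns (aggregate outputs, for example). With
this revert, `DataFrame.schema` reports that metadata again. A minimal sketch of the
difference the assertions above must tolerate, assuming a running Spark session (the
column `v` and the exact metadata value are illustrative):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import sum as sum_

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(3).selectExpr("id", "id * 2 AS v")

    # The aggregate column below gets an auto-generated alias; with this
    # revert its schema field is expected to carry internal metadata again.
    field = df.groupBy().agg(sum_("v")).schema["sum(v)"]
    print(field.metadata)  # e.g. {'__autoGeneratedAlias': 'true'}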
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 1fb0195b5203..0f0abfd4b856 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -22,6 +22,7 @@ from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql import SparkSession as PySparkSession
 from pyspark.sql.types import (
+    _drop_metadata,
     StringType,
     StructType,
     StructField,
@@ -1673,7 +1674,8 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
             )
         )

-        self.assertEqual(cdf.schema, sdf.schema)
+        # TODO: 'cdf.schema' has extra metadata {'__autoGeneratedAlias': 'true'}
+        self.assertEqual(_drop_metadata(cdf.schema), _drop_metadata(sdf.schema))
         self.assertEqual(cdf.collect(), sdf.collect())

     def test_csv_functions(self):
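The workaround above relies on the fact that `StructField` equality in PySpark takes
metadata into account, so two otherwise identical schemas compare unequal when only one
of them carries `__autoGeneratedAlias`. A minimal sketch using the `_drop_metadata`
helper that this commit adds to `python/pyspark/sql/types.py` (see the next diff; the
field name `v` is illustrative):

    from pyspark.sql.types import LongType, StructField, StructType, _drop_metadata

    a = StructType([StructField("v", LongType(), True, {"__autoGeneratedAlias": "true"})])
    b = StructType([StructField("v", LongType(), True)])

    assert a != b                                  # metadata participates in equality
    assert _drop_metadata(a) == _drop_metadata(b)  # equal once metadata is stripped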
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 62f09e948792..c72ff72ce426 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1770,6 +1770,19 @@ _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
 _COLLATIONS_METADATA_KEY = "__COLLATIONS"


+def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]:
+    assert isinstance(d, (DataType, StructField))
+    if isinstance(d, StructField):
+        return StructField(d.name, _drop_metadata(d.dataType), d.nullable, None)
+    elif isinstance(d, StructType):
+        return StructType([cast(StructField, _drop_metadata(f)) for f in d.fields])
+    elif isinstance(d, ArrayType):
+        return ArrayType(_drop_metadata(d.elementType), d.containsNull)
+    elif isinstance(d, MapType):
+        return MapType(_drop_metadata(d.keyType), _drop_metadata(d.valueType), d.valueContainsNull)
+    return d
+
+
 def _parse_datatype_string(s: str) -> DataType:
     """
     Parses the given data type string to a :class:`DataType`.
     The data type string format equals

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index afde54fc3d11..c7511737b2b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
-import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
 import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution._
@@ -561,7 +561,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def schema: StructType = sparkSession.withActive {
-    removeInternalMetadata(queryExecution.analyzed.schema)
+    queryExecution.analyzed.schema
   }

   /**

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index a89cae865435..620ee430cab2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -24,6 +24,7 @@ import scala.util.Random
 import org.scalatest.matchers.must.Matchers.the

 import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
+import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1464,7 +1465,7 @@ class DataFrameAggregateSuite extends QueryTest
         Duration.ofSeconds(14)) :: Nil)
     assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)

-    val metadata = Metadata.empty
+    val metadata = new MetadataBuilder().putString(AUTO_GENERATED_ALIAS, "true").build()
     assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
       StructField("sum(year-month)", YearMonthIntervalType(), metadata = metadata),
       StructField("sum(year)", YearMonthIntervalType(YEAR), metadata = metadata),
@@ -1598,7 +1599,7 @@ class DataFrameAggregateSuite extends QueryTest
         Duration.ofMinutes(4).plusSeconds(20), Duration.ofSeconds(7)) :: Nil)
     assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)

-    val metadata = Metadata.empty
+    val metadata = new MetadataBuilder().putString(AUTO_GENERATED_ALIAS, "true").build()
     assert(avgDF2.schema == StructType(Seq(
       StructField("class", IntegerType, false),
       StructField("avg(year-month)", YearMonthIntervalType(), metadata = metadata),
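Note that `_drop_metadata` recurses through struct, array, and map types, so metadata is
stripped at every nesting level rather than only from top-level fields. A quick sketch
(the nested shape and the metadata key `some_key` are illustrative):

    from pyspark.sql.types import StringType, StructField, StructType, _drop_metadata

    nested = StructType([
        StructField(
            "s",
            StructType([StructField("inner", StringType(), True, {"some_key": "v"})]),
            True,
            {"__autoGeneratedAlias": "true"},
        )
    ])

    stripped = _drop_metadata(nested)
    assert stripped["s"].metadata == {}                    # top-level metadata dropped
    assert stripped["s"].dataType["inner"].metadata == {}  # nested metadata dropped too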