This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1cbc424ae2a [SPARK-45346][SQL] Parquet schema inference should respect case sensitive flag when merging schema
1cbc424ae2a is described below

commit 1cbc424ae2acaf4d82f928cfea2767c81425305e
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Wed Sep 27 16:00:11 2023 +0800

    [SPARK-45346][SQL] Parquet schema inference should respect case sensitive flag when merging schema
    
    ### What changes were proposed in this pull request?
    
    Currently, when we infer the schema from Parquet files and merge the per-file schemas, the merge is always case-sensitive. A later check, which verifies case-insensitively (the default) that the Parquet data schema has no duplicated columns, then fails.
    
    This PR fixes the problem and makes schema merging respect the case-sensitivity flag.
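
    Conceptually, the fix normalizes field names before using them as lookup keys during the merge. A minimal standalone sketch of the normalization the PR adds (hoisted out of the private `StructType.merge` for illustration; the two-argument `normalize` signature here is illustrative):

    ```scala
    import java.util.Locale

    // Under case-insensitive analysis, names are lowered with Locale.ROOT
    // before being matched, so "col" and "COL" refer to the same field.
    def normalize(name: String, caseSensitive: Boolean): String =
      if (caseSensitive) name else name.toLowerCase(Locale.ROOT)

    normalize("COL", caseSensitive = true)  // "COL"
    normalize("COL", caseSensitive = false) // "col"
    ```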
    
    ### Why are the changes needed?
    
    Bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Spark can now read Parquet files that previously failed schema merging due to the case-insensitive duplicate-column check.
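
    For example, a read like the following (mirroring the new test; the path and column names are illustrative, and a `spark` session with `spark.implicits._` imported is assumed) used to fail under the default case-insensitive mode and now succeeds:

    ```scala
    // Two Parquet files whose column names differ only by case.
    Seq(1).toDF("col").write.mode("append").parquet("/tmp/t")
    Seq(2).toDF("COL").write.mode("append").parquet("/tmp/t")

    // With spark.sql.caseSensitive=false (the default), this read used to
    // fail the duplicate-column check; with the fix it yields one column.
    spark.read.option("mergeSchema", "true").parquet("/tmp/t").show()
    ```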
    
    ### How was this patch tested?
    
    New tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43134 from cloud-fan/merge-schema.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/sql/types/StructType.scala    | 28 +++++++++++++++-------
 .../execution/datasources/SchemaMergeUtils.scala   |  5 ++--
 .../datasources/parquet/ParquetSchemaSuite.scala   | 21 ++++++++++++++++
 3 files changed, 43 insertions(+), 11 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 8edc7cf370b..8fd7f47b346 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.types
 
+import java.util.Locale
+
 import scala.collection.{mutable, Map}
 import scala.util.Try
 import scala.util.control.NonFatal
@@ -476,8 +478,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
    * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be
    *    thrown.
    */
-  private[sql] def merge(that: StructType): StructType =
-    StructType.merge(this, that).asInstanceOf[StructType]
+  private[sql] def merge(that: StructType, caseSensitive: Boolean = true): StructType =
+    StructType.merge(this, that, caseSensitive).asInstanceOf[StructType]
 
   override private[spark] def asNullable: StructType = {
     val newFields = fields.map {
@@ -561,16 +563,20 @@ object StructType extends AbstractDataType {
       StructType(newFields)
     })
 
-  private[sql] def merge(left: DataType, right: DataType): DataType =
+  private[sql] def merge(left: DataType, right: DataType, caseSensitive: Boolean = true): DataType =
     mergeInternal(left, right, (s1: StructType, s2: StructType) => {
       val leftFields = s1.fields
       val rightFields = s2.fields
       val newFields = mutable.ArrayBuffer.empty[StructField]
 
-      val rightMapped = fieldsMap(rightFields)
+      def normalize(name: String): String = {
+        if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+      }
+
+      val rightMapped = fieldsMap(rightFields, caseSensitive)
       leftFields.foreach {
         case leftField @ StructField(leftName, leftType, leftNullable, _) =>
-          rightMapped.get(leftName)
+          rightMapped.get(normalize(leftName))
            .map { case rightField @ StructField(rightName, rightType, rightNullable, _) =>
               try {
                 leftField.copy(
@@ -588,9 +594,9 @@ object StructType extends AbstractDataType {
             .foreach(newFields += _)
       }
 
-      val leftMapped = fieldsMap(leftFields)
+      val leftMapped = fieldsMap(leftFields, caseSensitive)
       rightFields
-        .filterNot(f => leftMapped.get(f.name).nonEmpty)
+        .filterNot(f => leftMapped.contains(normalize(f.name)))
         .foreach { f =>
           newFields += f
         }
@@ -643,11 +649,15 @@ object StructType extends AbstractDataType {
         throw DataTypeErrors.cannotMergeIncompatibleDataTypesError(left, right)
     }
 
-  private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = {
+  private[sql] def fieldsMap(
+      fields: Array[StructField],
+      caseSensitive: Boolean = true): Map[String, StructField] = {
     // Mimics the optimization of breakOut, not present in Scala 2.13, while working in 2.12
     val map = mutable.Map[String, StructField]()
     map.sizeHint(fields.length)
-    fields.foreach(s => map.put(s.name, s))
+    fields.foreach { s =>
+      if (caseSensitive) map.put(s.name, s) else map.put(s.name.toLowerCase(Locale.ROOT), s)
+    }
     map
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala
index 35d9b5d6034..cf0e67ecc30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala
@@ -64,6 +64,7 @@ object SchemaMergeUtils extends Logging {
 
     val ignoreCorruptFiles =
       new FileSourceOptions(CaseInsensitiveMap(parameters)).ignoreCorruptFiles
+    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
 
     // Issues a Spark job to read Parquet/ORC schema in parallel.
     val partiallyMergedSchemas =
@@ -84,7 +85,7 @@ object SchemaMergeUtils extends Logging {
             var mergedSchema = schemas.head
             schemas.tail.foreach { schema =>
               try {
-                mergedSchema = mergedSchema.merge(schema)
+                mergedSchema = mergedSchema.merge(schema, caseSensitive)
               } catch { case cause: SparkException =>
                throw QueryExecutionErrors.failedMergingSchemaError(mergedSchema, schema, cause)
               }
@@ -99,7 +100,7 @@ object SchemaMergeUtils extends Logging {
       var finalSchema = partiallyMergedSchemas.head
       partiallyMergedSchemas.tail.foreach { schema =>
         try {
-          finalSchema = finalSchema.merge(schema)
+          finalSchema = finalSchema.merge(schema, caseSensitive)
         } catch { case cause: SparkException =>
          throw QueryExecutionErrors.failedMergingSchemaError(finalSchema, schema, cause)
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 30f46a3cac2..facc9b90ff7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -996,6 +996,27 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
     }
   }
 
+  test("SPARK-45346: merge schema should respect case sensitivity") {
+    import testImplicits._
+    Seq(true, false).foreach { caseSensitive =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        withTempPath { path =>
+          Seq(1).toDF("col").write.mode("append").parquet(path.getCanonicalPath)
+          Seq(2).toDF("COL").write.mode("append").parquet(path.getCanonicalPath)
+          val df = spark.read.option("mergeSchema", "true").parquet(path.getCanonicalPath)
+          if (caseSensitive) {
+            assert(df.columns.toSeq.sorted == Seq("COL", "col"))
+            assert(df.collect().length == 2)
+          } else {
+            // The final column name depends on which file is listed first, and is a bit random.
+            assert(df.columns.toSeq.map(_.toLowerCase(java.util.Locale.ROOT)) == Seq("col"))
+            assert(df.collect().length == 2)
+          }
+        }
+      }
+    }
+  }
+
   // =======================================
   // Tests for parquet schema mismatch error
   // =======================================


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
