This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new d393b50e6d5 [SPARK-43380][SQL] Fix slowdown in Avro read
d393b50e6d5 is described below

commit d393b50e6d5e64976747c9e84e3787366dbe4280
Author: zeruibao <zerui....@databricks.com>
AuthorDate: Fri Oct 27 14:45:33 2023 +0800

    [SPARK-43380][SQL] Fix slowdown in Avro read
    
    ### What changes were proposed in this pull request?
    Fix a slowdown in Avro read. https://github.com/apache/spark/pull/42503 introduced
    a performance regression: creating an `AvroOptions` instance inside `toSqlType` is
    very expensive, so pass the already-parsed `useStableIdForUnionType` flag down the
    call stack instead.
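
    The shape of the fix, as a minimal standalone Scala sketch (all names here are
    hypothetical stand-ins, not Spark's real classes): hoist the expensive option
    parsing out of the per-call hot path and thread only the primitive flag through.

    ```scala
    // Hypothetical stand-in for AvroOptions: construction does non-trivial
    // parsing/validation work, which is wasteful when repeated per conversion.
    final case class ParsedOptions(params: Map[String, String]) {
      val useStableIdForUnionType: Boolean =
        params.getOrElse("enableStableIdentifiersForUnionType", "false").toBoolean
    }

    object OptionPassingSketch {
      // Before the fix: rebuilds the options object on every call (the regression).
      def toSqlTypeSlow(schema: String, params: Map[String, String]): String =
        convert(schema, ParsedOptions(params).useStableIdForUnionType)

      // After the fix: the caller parses options once and passes the flag down.
      def toSqlTypeFast(schema: String, useStableIdForUnionType: Boolean): String =
        convert(schema, useStableIdForUnionType)

      // Illustrative conversion body standing in for the real schema walk.
      private def convert(schema: String, stableIds: Boolean): String =
        if (stableIds) s"member_$schema" else schema

      def main(args: Array[String]): Unit = {
        val flag = ParsedOptions(
          Map("enableStableIdentifiersForUnionType" -> "true")).useStableIdForUnionType
        println(toSqlTypeFast("int", flag)) // prints: member_int
      }
    }
    ```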
    
    After the regression:
    
![image-20231024-193909](https://github.com/apache/spark/assets/125398515/c6af9376-e058-4da9-8f63-d9e8663b36ef)
    
    Before the regression:
    
![image-20231024-193650](https://github.com/apache/spark/assets/125398515/fd609c05-accb-4ce8-8020-2866328a52f7)
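
    For callers that already hold parsed options, the new `toSqlType` overload added
    in `SchemaConverters` below can be invoked directly. A minimal usage sketch,
    assuming spark-avro from this branch on the classpath (the schema itself is
    illustrative):

    ```scala
    import org.apache.avro.Schema
    import org.apache.spark.sql.avro.SchemaConverters

    object ToSqlTypeExample {
      def main(args: Array[String]): Unit = {
        val avroSchema = new Schema.Parser().parse(
          """{"type":"record","name":"r","fields":[{"name":"v","type":"int"}]}""")
        // Pass the already-known flag instead of a Map that would be re-parsed
        // into an AvroOptions instance on every call.
        val schemaType = SchemaConverters.toSqlType(avroSchema, false)
        println(schemaType.dataType.catalogString) // struct<v:int>
      }
    }
    ```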
    
    ### Why are the changes needed?
    We need to fix the performance regression in Avro read.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43530 from zeruibao/SPARK-4380-real-fix-regression-2.
    
    Lead-authored-by: zeruibao <zerui....@databricks.com>
    Co-authored-by: Zerui Bao <125398515+zerui...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 7d94c5769a8b95a2811e73527fa6ea60f9087901)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/avro/AvroDataToCatalyst.scala |  3 +-
 .../apache/spark/sql/avro/AvroDeserializer.scala   | 11 +++++---
 .../org/apache/spark/sql/avro/AvroFileFormat.scala |  3 +-
 .../apache/spark/sql/avro/SchemaConverters.scala   | 32 ++++++++++++++--------
 .../sql/v2/avro/AvroPartitionReaderFactory.scala   |  3 +-
 .../sql/avro/AvroCatalystDataConversionSuite.scala |  7 +++--
 .../apache/spark/sql/avro/AvroRowReaderSuite.scala |  3 +-
 .../org/apache/spark/sql/avro/AvroSerdeSuite.scala |  3 +-
 .../org/apache/spark/sql/avro/AvroSuite.scala      |  2 +-
 9 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
index 59f2999bdd3..2c2a45fc3f1 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
@@ -61,7 +61,8 @@ private[sql] case class AvroDataToCatalyst(
   @transient private lazy val reader = new GenericDatumReader[Any](actualSchema, expectedSchema)
 
   @transient private lazy val deserializer =
-    new AvroDeserializer(expectedSchema, dataType, avroOptions.datetimeRebaseModeInRead)
+    new AvroDeserializer(expectedSchema, dataType,
+      avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType)
 
   @transient private var decoder: BinaryDecoder = _
 
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index e82116eec1e..fe0bd7392b6 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -49,18 +49,21 @@ private[sql] class AvroDeserializer(
     rootCatalystType: DataType,
     positionalFieldMatch: Boolean,
     datetimeRebaseSpec: RebaseSpec,
-    filters: StructFilters) {
+    filters: StructFilters,
+    useStableIdForUnionType: Boolean) {
 
   def this(
       rootAvroType: Schema,
       rootCatalystType: DataType,
-      datetimeRebaseMode: String) = {
+      datetimeRebaseMode: String,
+      useStableIdForUnionType: Boolean) = {
     this(
       rootAvroType,
       rootCatalystType,
       positionalFieldMatch = false,
       RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)),
-      new NoopFilters)
+      new NoopFilters,
+      useStableIdForUnionType)
   }
 
   private lazy val decimalConversions = new DecimalConversion()
@@ -118,7 +121,7 @@ private[sql] class AvroDeserializer(
     val incompatibleMsg = errorPrefix +
         s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})"
 
-    val realDataType = SchemaConverters.toSqlType(avroType).dataType
+    val realDataType = SchemaConverters.toSqlType(avroType, useStableIdForUnionType).dataType
     val confKey = SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA
     val preventReadingIncorrectType = !SQLConf.get.getConf(confKey)
 
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index 53562a3afdb..7b0292df43c 100755
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -141,7 +141,8 @@ private[sql] class AvroFileFormat extends FileFormat
             requiredSchema,
             parsedOptions.positionalFieldMatching,
             datetimeRebaseMode,
-            avroFilters)
+            avroFilters,
+            parsedOptions.useStableIdForUnionType)
           override val stopPosition = file.start + file.length
 
           override def hasNext: Boolean = hasNextRow
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
index 6f21639e28d..06abe977e3b 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -46,16 +46,24 @@ object SchemaConverters {
    */
   case class SchemaType(dataType: DataType, nullable: Boolean)
 
+  /**
+   * Converts an Avro schema to a corresponding Spark SQL schema.
+   *
+   * @since 4.0.0
+   */
+  def toSqlType(avroSchema: Schema, useStableIdForUnionType: Boolean): SchemaType = {
+    toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType)
+  }
   /**
    * Converts an Avro schema to a corresponding Spark SQL schema.
    *
    * @since 2.4.0
    */
   def toSqlType(avroSchema: Schema): SchemaType = {
-    toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(Map()))
+    toSqlType(avroSchema, false)
   }
   def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = {
-    toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(options))
+    toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(options).useStableIdForUnionType)
   }
 
   // The property specifies Catalyst type of the given field
@@ -64,7 +72,7 @@ object SchemaConverters {
   private def toSqlTypeHelper(
       avroSchema: Schema,
       existingRecordNames: Set[String],
-      avroOptions: AvroOptions): SchemaType = {
+      useStableIdForUnionType: Boolean): SchemaType = {
     avroSchema.getType match {
       case INT => avroSchema.getLogicalType match {
         case _: Date => SchemaType(DateType, nullable = false)
@@ -117,7 +125,7 @@ object SchemaConverters {
         }
         val newRecordNames = existingRecordNames + avroSchema.getFullName
         val fields = avroSchema.getFields.asScala.map { f =>
-          val schemaType = toSqlTypeHelper(f.schema(), newRecordNames, avroOptions)
+          val schemaType = toSqlTypeHelper(f.schema(), newRecordNames, useStableIdForUnionType)
           StructField(f.name, schemaType.dataType, schemaType.nullable)
         }
 
@@ -127,13 +135,14 @@ object SchemaConverters {
         val schemaType = toSqlTypeHelper(
           avroSchema.getElementType,
           existingRecordNames,
-          avroOptions)
+          useStableIdForUnionType)
         SchemaType(
           ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
           nullable = false)
 
       case MAP =>
-        val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames, avroOptions)
+        val schemaType = toSqlTypeHelper(avroSchema.getValueType,
+          existingRecordNames, useStableIdForUnionType)
         SchemaType(
           MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
           nullable = false)
@@ -143,17 +152,18 @@ object SchemaConverters {
           // In case of a union with null, eliminate it and make a recursive call
           val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema)
           if (remainingUnionTypes.size == 1) {
-            toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames, avroOptions)
+            toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames, useStableIdForUnionType)
               .copy(nullable = true)
           } else {
             toSqlTypeHelper(
               Schema.createUnion(remainingUnionTypes.asJava),
               existingRecordNames,
-              avroOptions).copy(nullable = true)
+              useStableIdForUnionType).copy(nullable = true)
           }
         } else avroSchema.getTypes.asScala.map(_.getType).toSeq match {
           case Seq(t1) =>
-            toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames, avroOptions)
+            toSqlTypeHelper(avroSchema.getTypes.get(0),
+              existingRecordNames, useStableIdForUnionType)
           case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
             SchemaType(LongType, nullable = false)
           case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
@@ -167,9 +177,9 @@ object SchemaConverters {
             val fieldNameSet : mutable.Set[String] = mutable.Set()
             val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
               case (s, i) =>
-                val schemaType = toSqlTypeHelper(s, existingRecordNames, avroOptions)
+                val schemaType = toSqlTypeHelper(s, existingRecordNames, useStableIdForUnionType)
 
-                val fieldName = if (avroOptions.useStableIdForUnionType) {
+                val fieldName = if (useStableIdForUnionType) {
                   // Avro's field name may be case sensitive, so field names for two named type
                   // could be "a" and "A" and we need to distinguish them. In this case, we throw
                   // an exception.
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
index cc7bd180e84..2c85c1b0673 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
@@ -103,7 +103,8 @@ case class AvroPartitionReaderFactory(
           readDataSchema,
           options.positionalFieldMatching,
           datetimeRebaseMode,
-          avroFilters)
+          avroFilters,
+          options.useStableIdForUnionType)
         override val stopPosition = partitionedFile.start + partitionedFile.length
 
         override def next(): Boolean = hasNextRow
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
index 1cb34a0bc4d..250b5e0615a 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
@@ -59,7 +59,7 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
 
     val expected = {
       val avroSchema = new Schema.Parser().parse(schema)
-      SchemaConverters.toSqlType(avroSchema).dataType match {
+      SchemaConverters.toSqlType(avroSchema, false).dataType match {
         case st: StructType => Row.fromSeq((0 until st.length).map(_ => null))
         case _ => null
       }
@@ -281,13 +281,14 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
       data: GenericData.Record,
       expected: Option[Any],
       filters: StructFilters = new NoopFilters): Unit = {
-    val dataType = SchemaConverters.toSqlType(schema).dataType
+    val dataType = SchemaConverters.toSqlType(schema, false).dataType
     val deserializer = new AvroDeserializer(
       schema,
       dataType,
       false,
       RebaseSpec(LegacyBehaviorPolicy.CORRECTED),
-      filters)
+      filters,
+      false)
     val deserialized = deserializer.deserialize(data)
     expected match {
       case None => assert(deserialized == None)
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
index 70d0bc6c0ad..965e3a0c1cb 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
@@ -75,7 +75,8 @@ class AvroRowReaderSuite
           StructType(new StructField("value", IntegerType, true) :: Nil),
           false,
           RebaseSpec(CORRECTED),
-          new NoopFilters)
+          new NoopFilters,
+          false)
         override val stopPosition = fileSize
 
         override def hasNext: Boolean = hasNextRow
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala
index 7f99f3c737c..a21f3f008fd 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala
@@ -226,7 +226,8 @@ object AvroSerdeSuite {
         sql,
         isPositional(matchType),
         RebaseSpec(CORRECTED),
-        new NoopFilters)
+        new NoopFilters,
+        false)
   }
 
   /**
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index ffb0a49641b..1df99210a55 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -2137,7 +2137,7 @@ abstract class AvroSuite
 
   private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = {
     val message = intercept[IncompatibleSchemaException] {
-      SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema))
+      SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false)
     }.getMessage
 
     assert(message.contains("Found recursive reference in Avro schema"))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
