This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new fa2e9c7275aa [SPARK-47739][SQL] Register logical avro type fa2e9c7275aa is described below commit fa2e9c7275aa1c09652d0df0992565c32974b2b9 Author: milastdbx <milan.stefano...@databricks.com> AuthorDate: Tue Apr 16 03:38:19 2024 -0700 [SPARK-47739][SQL] Register logical avro type ### What changes were proposed in this pull request? In this pull request I propose that we register logical Avro types when we initialize `AvroUtils` and `AvroFileFormat`; otherwise, the very first schema discovery after Spark starts might return a wrong result. <img width="1727" alt="image" src="https://github.com/apache/spark/assets/150366084/3eaba6e3-34ec-4ca9-ae89-d0259ce942ba"> example ```scala val new_schema = """ | { | "type": "record", | "name": "Entry", | "fields": [ | { | "name": "rate", | "type": [ | "null", | { | "type": "long", | "logicalType": "custom-decimal", | "precision": 38, | "scale": 9 | } | ], | "default": null | } | ] | }""".stripMargin spark.read.format("avro").option("avroSchema", new_schema).load().printSchema // maps to long - WRONG spark.read.format("avro").option("avroSchema", new_schema).load().printSchema // maps to Decimal - CORRECT ``` ### Why are the changes needed? To fix an issue with resolving the Avro schema upon Spark startup. ### Does this PR introduce _any_ user-facing change? No, it's a bug fix. ### How was this patch tested? Unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #45895 from milastdbx/dev/milast/fixAvroLogicalTypeRegistration. 
Lead-authored-by: milastdbx <milan.stefano...@databricks.com> Co-authored-by: Dongjoon Hyun <dongj...@apache.org> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../org/apache/spark/sql/avro/AvroFileFormat.scala | 21 ++++-- .../spark/sql/avro/AvroLogicalTypeInitSuite.scala | 76 ++++++++++++++++++++++ 2 files changed, 91 insertions(+), 6 deletions(-) diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 2792edaea284..372f24b54f5c 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -43,6 +43,8 @@ import org.apache.spark.util.SerializableConfiguration private[sql] class AvroFileFormat extends FileFormat with DataSourceRegister with Logging with Serializable { + AvroFileFormat.registerCustomAvroTypes() + override def equals(other: Any): Boolean = other match { case _: AvroFileFormat => true case _ => false @@ -173,10 +175,17 @@ private[sql] class AvroFileFormat extends FileFormat private[avro] object AvroFileFormat { val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension" - // Register the customized decimal type backed by long. - LogicalTypes.register(CustomDecimal.TYPE_NAME, new LogicalTypes.LogicalTypeFactory { - override def fromSchema(schema: Schema): LogicalType = { - new CustomDecimal(schema) - } - }) + /** + * Register Spark defined custom Avro types. + */ + def registerCustomAvroTypes(): Unit = { + // Register the customized decimal type backed by long. 
+ LogicalTypes.register(CustomDecimal.TYPE_NAME, new LogicalTypes.LogicalTypeFactory { + override def fromSchema(schema: Schema): LogicalType = { + new CustomDecimal(schema) + } + }) + } + + registerCustomAvroTypes() } diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeInitSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeInitSuite.scala new file mode 100644 index 000000000000..126440ed69b8 --- /dev/null +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeInitSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.avro + +import org.apache.spark.SparkConf +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.DecimalType + +/** + * Test suite for Avro logical type initialization. 
+ * Tests here must run in isolation, otherwise some other test might + * initialize variable and make this test flaky + */ +abstract class AvroLogicalTypeInitSuite + extends QueryTest + with SharedSparkSession { + + test("SPARK-47739: custom logical type registration test") { + val avroTypeJson = + """ + |{ + | "type": "record", + | "name": "Entry", + | "fields": [ + | { + | "name": "test_col", + | "type": [ + | "null", + | { + | "type": "long", + | "logicalType": "custom-decimal", + | "precision": 38, + | "scale": 9 + | } + | ], + | "default": null + | } + | ] + |} + | + """.stripMargin + + val df = spark.read.format("avro").option("avroSchema", avroTypeJson).load() + assert(df.schema.fields(0).dataType == DecimalType(38, 9)) + } +} + +class AvroV1LogicalTypeInitSuite extends AvroLogicalTypeInitSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "avro") +} + +class AvroV2LogicalTypeInitSuite extends AvroLogicalTypeInitSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "") +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org