This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new b9c0e9331935 [SPARK-47309][SQL] XML: Add schema inference tests for value tags
b9c0e9331935 is described below

commit b9c0e93319350caa4beecdcd42051449ec1f9c08
Author: Shujing Yang <shujing.y...@databricks.com>
AuthorDate: Wed Mar 20 10:16:43 2024 +0900

    [SPARK-47309][SQL] XML: Add schema inference tests for value tags

    ### What changes were proposed in this pull request?

    Add schema inference tests for corrupt records, null values, and value tags.
    For value tags, this PR adds the following tests:

    1. Conflicts between primitive types
    2. Root-level value tags
    3. Empty value tags in some rows
    4. Arrays of value tags:
        1) values split into multiple lines
        2) interspersed in nested structs: empty fields and optional fields in structs
        3) interspersed in arrays and value tags: empty fields and optional fields in structs
        4) name conflicts
        5) CDATA and comments
        6) no spaces / some spaces / whitespace between value tags and elements

    ### Why are the changes needed?

    This is a test-only change.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    This is a test-only change.

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #45538 from shujingyang-db/xml-inference-test.

    Lead-authored-by: Shujing Yang <shujing.y...@databricks.com>
    Co-authored-by: Shujing Yang <135740748+shujingyang...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/catalyst/xml/XmlInferSchema.scala    |   8 +-
 .../execution/datasources/xml/TestXmlData.scala    | 269 ++++++++++++++++
 .../datasources/xml/XmlInferSchemaSuite.scala      | 338 ++++++++++++++++++++-
 3 files changed, 613 insertions(+), 2 deletions(-)
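[Editor's note] For context on the feature under test: a "value tag" captures character data that appears next to child elements (mixed content); the XML source models it as a synthetic struct field, `_VALUE` by default. A minimal sketch of the two merge behaviors described in the comment added below, not part of this commit; it assumes Spark's built-in XML source and a local SparkSession:

```scala
import org.apache.spark.sql.{Encoders, SparkSession}

object ValueTagMergeSketch extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()

  def infer(rows: Seq[String]) =
    spark.read.option("rowTag", "ROW")
      .xml(spark.createDataset(rows)(Encoders.STRING)).schema

  // Case 1: a struct without a value tag merged with a bare primitive;
  // per the comment added in XmlInferSchema.scala, <a> falls back to string.
  println(infer(Seq("<ROW><a><b>1</b></a></ROW>", "<ROW><a>99</a></ROW>")))

  // Case 2: a struct that already carries a value tag merged with a primitive;
  // <a> stays a struct and its value-tag field takes the compatible type.
  println(infer(Seq("<ROW><a>17<b>1</b></a></ROW>", "<ROW><a>99</a></ROW>")))

  spark.stop()
}
```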
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
index b9342c53d020..4640f86d5997 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
@@ -37,7 +37,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DropMalformedMode, FailFastMode, ParseMode, PermissiveMode, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
-import org.apache.spark.sql.catalyst.xml.XmlInferSchema.compatibleType
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
@@ -46,6 +45,8 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
   extends Serializable
   with Logging {
 
+  import org.apache.spark.sql.catalyst.xml.XmlInferSchema._
+
   private val decimalParser = ExprUtils.getDecimalParser(options.locale)
 
   private val timestampFormatter = TimestampFormatter(
@@ -120,6 +121,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       case Some(st: StructType) => st
       case _ =>
         // canonicalizeType erases all empty structs, including the only one we want to keep
+        // XML shouldn't run into this line
         StructType(Seq())
     }
   }
@@ -541,6 +543,10 @@ object XmlInferSchema {
       // As this library can infer an element with attributes as StructType whereas
       // some can be inferred as other non-structural data types, this case should be
       // treated.
+      // 1. Without value tags, combining structs and primitive types defaults to string type
+      // 2. With value tags, combining structs and primitive types defaults to
+      //    a struct with a value tag of the compatible type
+      // This behavior stays aligned with JSON
       case (st: StructType, dt: DataType) if st.fieldNames.contains(valueTag) =>
         val valueIndex = st.fieldNames.indexOf(valueTag)
         val valueField = st.fields(valueIndex)
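[Editor's note] The case shown above merges a struct that already contains a value-tag field with a bare primitive. A simplified, hypothetical distillation of that rule (the helper name `mergeStructWithPrimitive` is illustrative; the real logic lives in `XmlInferSchema.compatibleType`):

```scala
import org.apache.spark.sql.types._

// Sketch: when a struct that already carries a value-tag field meets a
// primitive, widen the value-tag field instead of collapsing the whole
// struct to a string.
def mergeStructWithPrimitive(
    st: StructType,
    dt: DataType,
    valueTag: String)(compatible: (DataType, DataType) => DataType): StructType = {
  val valueIndex = st.fieldNames.indexOf(valueTag)
  val valueField = st.fields(valueIndex)
  val widened = valueField.copy(dataType = compatible(valueField.dataType, dt))
  StructType(st.fields.updated(valueIndex, widened))
}

// Example: struct<_VALUE: bigint, c: bigint> merged with a string primitive
// becomes struct<_VALUE: string, c: bigint> under a string-fallback rule.
val merged = mergeStructWithPrimitive(
  new StructType().add("_VALUE", LongType).add("c", LongType),
  StringType,
  "_VALUE")((a, b) => if (a == b) a else StringType)
```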
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala
index 704a02482ada..616ccda62fc2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala
@@ -361,6 +361,52 @@ private[xml] trait TestXmlData {
     |</ROW>
     |""".stripMargin :: Nil
 
+  def nullsInArrays: Seq[String] =
+    """<ROW>
+        <field1>
+          <array1>
+            <array2>value1</array2>
+            <array2>value2</array2>
+          </array1>
+          <array1/>
+        </field1>
+        <field1/>
+      </ROW>""" ::
+      """
+      <ROW>
+        <field2/>
+        <field2>
+          <array1>
+            <Test>1</Test>
+          </array1>
+          <array1/>
+        </field2>
+      </ROW>""" ::
+      """
+      <ROW>
+        <field1/>
+        <field1><array1/></field1>
+        <field2/>
+      </ROW>""" :: Nil
+
+  def corruptRecords: Seq[String] =
+    """<ROW>""" ::
+      """""" ::
+      """<ROW>
+        |  <a>1</a>
+        |  <b>2</b>
+        |</ROW>""".stripMargin ::
+      """
+        |<ROW>
+        |  <a>str_a_4</a>
+        |  <b>str_b_4</b>
+        |  <c>str_c_4</c>
+        |</ROW>
+        |""".stripMargin ::
+      """
+        |</ROW>
+        |""".stripMargin :: Nil
+
   def emptyRecords: Seq[String] =
     """<ROW>
         <a><struct></struct></a>
@@ -378,4 +424,227 @@
         <item/>
       </b>
     </ROW>""" :: Nil
+
+  def arrayAndStructRecords: Seq[String] =
+    """<ROW>
+      <a>
+        <b>1</b>
+      </a>
+    </ROW>""" ::
+      """<ROW>
+      <a><item/><item/></a>
+    </ROW>""" ::
+      Nil
+
+  def valueTagsTypeConflict: Seq[String] =
+    """
+      |<ROW>
+      |    13.1
+      |    <a>
+      |        11
+      |        <b>
+      |            true
+      |            <c>1</c>
+      |        </b>
+      |    </a>
+      |    string
+      |</ROW>
+      |""".stripMargin ::
+      """
+      |<ROW>
+      |    string
+      |    <a>
+      |        21474836470
+      |        <b>
+      |            false
+      |            <c>2</c>
+      |        </b>
+      |    </a>
+      |    true
+      |</ROW>
+      |""".stripMargin ::
+      """
+      |<ROW>
+      |<a>
+      |    <b>
+      |        12
+      |        <c>3</c>
+      |    </b>
+      |</a>
+      |92233720368547758070
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val emptyValueTags: Seq[String] =
+    """
+      |<ROW>
+      |    str1
+      |    <a> <b>1</b>
+      |    </a>str2
+      |</ROW>
+      |""".stripMargin ::
+      """<ROW> <a><b/> value</a></ROW>""" ::
+      """<ROW><a><b>3</b> </a> </ROW>""" ::
+      """<ROW><a><b>4</b> </a>
+        |    str3
+        |</ROW>""".stripMargin :: Nil
+
+  val multilineValueTags =
+    """
+      |<ROW>
+      |    value1
+      |    <a>1</a>
+      |    value2
+      |</ROW>
+      |""".stripMargin ::
+      """
+      |<ROW>
+      |    value3
+      |    value4<a>1</a>
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val valueTagsAroundStructs =
+    """
+      |<ROW>
+      |    value1
+      |    <a>
+      |        value2
+      |        <b>
+      |            3
+      |            <c>1</c>
+      |        </b>
+      |        value4
+      |    </a>
+      |    value5
+      |</ROW>
+      |""".stripMargin ::
+      """
+      |<ROW>
+      |    value1
+      |    <a>
+      |        value2
+      |        <b>3</b>
+      |        value4
+      |    </a>
+      |</ROW>
+      |""".stripMargin ::
+      """
+      |<ROW>
+      |    <a>
+      |        <b></b>
+      |        value4
+      |        <!--First comment-->
+      |        value5
+      |    </a>
+      |    value6
+      |</ROW>
+      |""".stripMargin ::
+      """
+      |<ROW>
+      |    value1
+      |    <a>
+      |        value2
+      |        <b>
+      |            3
+      |            <c/>
+      |        </b>
+      |        value4
+      |    </a>
+      |    value5
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val valueTagsAroundArrays =
+    """
+      |<ROW>
+      |    value1
+      |    <array1>
+      |        value2
+      |        <array2>
+      |            1
+      |            <num>1</num>
+      |            2
+      |        </array2>
+      |        value3
+      |        <!--First comment--> <!--Second comment-->
+      |        value4<!--Third comment-->
+      |        value5
+      |        <array2>2</array2>value6
+      |        value7
+      |    </array1>
+      |    value8
+      |    <array1>
+      |        value9
+      |        <array2> <!--First comment--><num>2</num></array2>
+      |        value10
+      |        <array2></array2>
+      |        <array2> <!--First comment-->
+      |        <!--Second comment--></array2>
+      |        <array2>3</array2>
+      |        value11
+      |    </array1>
+      |    value12
+      |    <!--First comment-->
+      |    value13
+      |</ROW>
+      |""".stripMargin ::
+      """
+      |<ROW>
+      |    <array1>
+      |        value1
+      |    </array1>
+      |</ROW>
+      |""".stripMargin ::
+      """
+      |<ROW>
+      |    <array1>
+      |        <array2>
+      |            1
+      |        </array2>
+      |    </array1>
+      |    value1
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val valueTagConflictName =
+    """<ROW>
+      |    <a>1</a>
+      |    2
+      |</ROW>""".stripMargin :: Nil
+
+  val valueTagWithComments =
+    """
+      |<ROW>
+      |    <!--First comment-->
+      |    <!--Second comment-->
+      |    <a><!--First comment--></a>
+      |    <a attr="1"><!--First comment--> <!--Second comment--></a>
+      |    2
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val valueTagWithCDATA =
+    """
+      |<ROW>
+      |    <![CDATA[This is a CDATA section containing <sample1> text.]]>
+      |    <a>
+      |        <![CDATA[This is a CDATA section containing <sample2> text.]]>
+      |        <![CDATA[This is a CDATA section containing <sample3> text.]]>
+      |        <b>1</b>
+      |        <![CDATA[This is a CDATA section containing <sample4> text.]]>
+      |
+      |        <b>2</b>
+      |
+      |    </a>
+      |
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val valueTagIsNullValue =
+    """
+      |<ROW>
+      |    1
+      |</ROW>
+      |""".stripMargin :: Nil
 }
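[Editor's note] The fixtures above are plain XML strings. The suite below consumes them through a `readData` helper; a minimal free-standing equivalent looks roughly like this (a sketch assuming an active `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.{DataFrame, Encoders, SparkSession}

// Mirrors the suite's readData helper: distribute the strings as a dataset,
// then parse them with the built-in XML source, always setting rowTag=ROW.
def readXmlStrings(
    spark: SparkSession,
    xml: Seq[String],
    options: Map[String, String] = Map.empty): DataFrame = {
  val ds = spark.createDataset(spark.sparkContext.parallelize(xml))(Encoders.STRING)
  spark.read.options(Map("rowTag" -> "ROW") ++ options).xml(ds)
}

// e.g. readXmlStrings(spark, corruptRecords, Map("mode" -> "DROPMALFORMED"))
```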
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
index 697bd3d8b824..286120ff40b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
@@ -16,13 +16,19 @@
  */
 package org.apache.spark.sql.execution.datasources.xml
 
+import java.io.File
+
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{
   ArrayType,
   BooleanType,
   DecimalType,
   DoubleType,
+  IntegerType,
   LongType,
   StringType,
   StructField,
@@ -31,7 +37,13 @@ import org.apache.spark.sql.types.{
 
 class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXmlData {
 
-  val baseOptions = Map("rowTag" -> "ROW")
+  private val baseOptions = Map("rowTag" -> "ROW")
+
+  private val ignoreSurroundingSpacesOptions = Map("ignoreSurroundingSpaces" -> "true")
+
+  private val notIgnoreSurroundingSpacesOptions = Map("ignoreSurroundingSpaces" -> "false")
+
+  private val valueTagName = "_VALUE"
 
   def readData(xmlString: Seq[String], options: Map[String, String] = Map.empty): DataFrame = {
     val dataset = spark.createDataset(spark.sparkContext.parallelize(xmlString))(Encoders.STRING)
@@ -293,4 +305,328 @@ class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXml
     assert(emptyDF.schema === expectedSchema)
   }
 
+  test("nulls in arrays") {
+    val expectedSchema = StructType(
+      StructField(
+        "field1",
+        ArrayType(
+          new StructType()
+            .add("array1", ArrayType(new StructType().add("array2", ArrayType(StringType))))
+        )
+      ) ::
+        StructField(
+          "field2",
+          ArrayType(
+            new StructType()
+              .add("array1", ArrayType(StructType(StructField("Test", LongType) :: Nil)))
+          )
+        ) :: Nil
+    )
+    val expectedAns = Seq(
+      Row(Seq(Row(Seq(Row(Seq("value1", "value2")), Row(null))), Row(null)), null),
+      Row(null, Seq(Row(null), Row(Seq(Row(1), Row(null))))),
+      Row(Seq(Row(null), Row(Seq(Row(null)))), Seq(Row(null)))
+    )
+    val xmlDF = readData(nullsInArrays)
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("corrupt records: fail fast mode") {
+    // Fail fast mode is covered in the test case "DSL test for failing fast" in XmlSuite.
+    val schemaOne = StructType(
+      StructField("a", StringType, true) ::
+        StructField("b", StringType, true) ::
+        StructField("c", StringType, true) :: Nil
+    )
+    // `DROPMALFORMED` mode should skip corrupt records
+    val xmlDFOne = readData(corruptRecords, Map("mode" -> "DROPMALFORMED"))
+    checkAnswer(
+      xmlDFOne,
+      Row("1", "2", null) ::
+        Row("str_a_4", "str_b_4", "str_c_4") :: Nil
+    )
+    assert(xmlDFOne.schema === schemaOne)
+  }
+
+  test("turn non-nullable schema into a nullable schema") {
+    // XML field is missing.
+    val missingFieldInput = """<ROW><c1>1</c1></ROW>"""
+    val missingFieldInputDS =
+      spark.createDataset(spark.sparkContext.parallelize(missingFieldInput :: Nil))(Encoders.STRING)
+    // XML field is null.
+    val nullValueInput = """<ROW><c1>1</c1><c2/></ROW>"""
+    val nullValueInputDS =
+      spark.createDataset(spark.sparkContext.parallelize(nullValueInput :: Nil))(Encoders.STRING)
+
+    val schema = StructType(
+      Seq(
+        StructField("c1", IntegerType, nullable = false),
+        StructField("c2", IntegerType, nullable = false)
+      )
+    )
+    val expected = schema.asNullable
+
+    Seq(missingFieldInputDS, nullValueInputDS).foreach { xmlStringDS =>
+      Seq("DROPMALFORMED", "FAILFAST", "PERMISSIVE").foreach { mode =>
+        val df = spark.read
+          .option("mode", mode)
+          .option("rowTag", "ROW")
+          .schema(schema)
+          .xml(xmlStringDS)
+        assert(df.schema == expected)
+        checkAnswer(df, Row(1, null) :: Nil)
+      }
+      withSQLConf(SQLConf.LEGACY_RESPECT_NULLABILITY_IN_TEXT_DATASET_CONVERSION.key -> "true") {
+        checkAnswer(
+          spark.read
+            .schema(
+              StructType(
+                StructField("c1", LongType, nullable = false) ::
+                  StructField("c2", LongType, nullable = false) :: Nil
+              )
+            )
+            .option("rowTag", "ROW")
+            .option("mode", "DROPMALFORMED")
+            .xml(xmlStringDS),
+          // It is for testing legacy configuration. This is technically a bug as
+          // `0` has to be `null` but the schema is non-nullable.
+          Row(1, 0)
+        )
+      }
+    }
+  }
+
+  test("XML with partitions") {
+    def makePartition(rdd: RDD[String], parent: File, partName: String, partValue: Any): File = {
+      val p = new File(parent, s"$partName=${partValue.toString}")
+      rdd.saveAsTextFile(p.getCanonicalPath)
+      p
+    }
+
+    withTempPath(root => {
+      withTempView("test_myxml_with_part") {
+        val d1 = new File(root, "d1=1")
+        // root/d1=1/col1=abc
+        makePartition(
+          sparkContext.parallelize(2 to 5).map(i => s"""<ROW><a>1</a><b>str$i</b></ROW>"""),
+          d1,
+          "col1",
+          "abc"
+        )
+
+        // root/d1=1/col1=abd
+        makePartition(
+          sparkContext.parallelize(6 to 10).map(i => s"""<ROW><a>1</a><c>str$i</c></ROW>"""),
+          d1,
+          "col1",
+          "abd"
+        )
+        val expectedSchema = new StructType()
+          .add("a", LongType)
+          .add("b", StringType)
+          .add("c", StringType)
+          .add("d1", IntegerType)
+          .add("col1", StringType)
+
+        val df = spark.read.option("rowTag", "ROW").xml(root.getAbsolutePath)
+        assert(df.schema === expectedSchema)
+        assert(df.where(col("d1") === 1).where(col("col1") === "abc").select("a").count() == 4)
+        assert(df.where(col("d1") === 1).where(col("col1") === "abd").select("a").count() == 5)
+        assert(df.where(col("d1") === 1).select("a").count() == 9)
+      }
+    })
+  }
+
+  test("value tag - type conflict and root level value tags") {
+    val xmlDF = readData(valueTagsTypeConflict, ignoreSurroundingSpacesOptions)
+    val expectedSchema = new StructType()
+      .add(valueTagName, ArrayType(StringType))
+      .add(
+        "a",
+        new StructType()
+          .add(valueTagName, LongType)
+          .add("b", new StructType().add(valueTagName, StringType).add("c", LongType))
+      )
+    assert(xmlDF.schema == expectedSchema)
+    val expectedAns = Seq(
+      Row(Seq("13.1", "string"), Row(11, Row("true", 1))),
+      Row(Seq("string", "true"), Row(21474836470L, Row("false", 2))),
+      Row(Seq("92233720368547758070"), Row(null, Row("12", 3)))
+    )
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tag - spaces and empty values") {
+    val expectedSchema = new StructType()
+      .add(valueTagName, ArrayType(StringType))
+      .add("a", new StructType().add(valueTagName, StringType).add("b", LongType))
+    // Even when we don't ignore the surrounding spaces of character data,
+    // whitespace-only text never becomes a value tag :)
+    val xmlDFWSpaces =
+      readData(emptyValueTags, notIgnoreSurroundingSpacesOptions)
+    val xmlDFWOSpaces = readData(emptyValueTags, ignoreSurroundingSpacesOptions)
+    assert(xmlDFWSpaces.schema == expectedSchema)
+    assert(xmlDFWOSpaces.schema == expectedSchema)
+
+    val expectedAnsWSpaces = Seq(
+      Row(Seq("\n    str1\n    ", "str2\n"), Row(null, 1)),
+      Row(null, Row(" value", null)),
+      Row(null, Row(null, 3)),
+      Row(Seq("\n    str3\n"), Row(null, 4))
+    )
+    checkAnswer(xmlDFWSpaces, expectedAnsWSpaces)
+    val expectedAnsWOSpaces = Seq(
+      Row(Seq("str1", "str2"), Row(null, 1)),
+      Row(null, Row("value", null)),
+      Row(null, Row(null, 3)),
+      Row(Seq("str3"), Row(null, 4))
+    )
+    checkAnswer(xmlDFWOSpaces, expectedAnsWOSpaces)
+  }
+
+  test("value tags - multiple lines") {
+    val xmlDF = readData(multilineValueTags, ignoreSurroundingSpacesOptions)
+    val expectedSchema =
+      new StructType().add(valueTagName, ArrayType(StringType)).add("a", LongType)
+    val expectedAns = Seq(
+      Row(Seq("value1", "value2"), 1),
+      Row(Seq("value3\n    value4"), 1)
+    )
+    assert(xmlDF.schema == expectedSchema)
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tags - around structs") {
+    val xmlDF = readData(valueTagsAroundStructs)
+    val expectedSchema = new StructType()
+      .add(valueTagName, ArrayType(StringType))
+      .add(
+        "a",
+        new StructType()
+          .add(valueTagName, ArrayType(StringType))
+          .add("b", new StructType().add(valueTagName, LongType).add("c", LongType))
+      )
+
+    assert(xmlDF.schema == expectedSchema)
+    val expectedAns = Seq(
+      Row(
+        Seq("value1", "value5"),
+        Row(Seq("value2", "value4"), Row(3, 1))
+      ),
+      Row(
+        Seq("value6"),
+        Row(Seq("value4", "value5"), Row(null, null))
+      ),
+      Row(
+        Seq("value1", "value5"),
+        Row(Seq("value2", "value4"), Row(3, null))
+      ),
+      Row(
+        Seq("value1"),
+        Row(Seq("value2", "value4"), Row(3, null))
+      )
+    )
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tags - around arrays") {
+    val xmlDF = readData(valueTagsAroundArrays)
+    val expectedSchema = new StructType()
+      .add(valueTagName, ArrayType(StringType))
+      .add(
+        "array1",
+        ArrayType(
+          new StructType()
+            .add(valueTagName, ArrayType(StringType))
+            .add(
+              "array2",
+              ArrayType(new StructType()
+                // The value tag is not of long type because:
+                // 1. when we infer the type of array2 in the second array1, it
+                //    combines a struct type and a primitive type, which results
+                //    in a string type
+                // 2. when we merge the inferred types of the first array2 and the
+                //    second, we merge a struct with a LongType value tag and a
+                //    string type. That merges the LongType value tag with the
+                //    primitive type, so we end up with a struct whose value tag
+                //    is of string type.
+                .add(valueTagName, ArrayType(StringType))
+                .add("num", LongType)))))
+    assert(xmlDF.schema === expectedSchema)
+    val expectedAns = Seq(
+      Row(
+        Seq("value1", "value8", "value12", "value13"),
+        Seq(
+          Row(
+            Seq("value2", "value3", "value4", "value5", "value6\n        value7"),
+            Seq(Row(Seq("1", "2"), 1), Row(Seq("2"), null))),
+          Row(
+            Seq("value9", "value10", "value11"),
+            Seq(Row(null, 2), Row(null, null), Row(null, null), Row(Seq("3"), null))))),
+      Row(
+        null,
+        Seq(
+          Row(
+            Seq("value1"), null))),
+      Row(
+        Seq("value1"),
+        Seq(
+          Row(
+            null,
+            Seq(Row(Seq("1"), null))))))
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tag - user specifies a conflicting name for valueTag") {
+    val xmlDF = readData(valueTagConflictName, Map("valueTag" -> "a"))
+    val expectedSchema = new StructType().add("a", ArrayType(LongType))
+    assert(xmlDF.schema == expectedSchema)
+    checkAnswer(xmlDF, Seq(Row(Seq(1, 2))))
+  }
+
+  test("value tag - comments") {
+    val xmlDF = readData(valueTagWithComments)
+    val expectedSchema = new StructType()
+      .add(valueTagName, LongType)
+      .add("a", ArrayType(new StructType().add("_attr", LongType)))
+    val expectedAns = Seq(
+      Row(2, Seq(Row(null), Row(1))))
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tags - CDATA") {
+    val xmlDF = readData(valueTagWithCDATA)
+    val expectedSchema = new StructType()
+      .add(valueTagName, StringType)
+      .add("a", new StructType()
+        .add(valueTagName, ArrayType(StringType))
+        .add("b", ArrayType(LongType)))
+
+    val expectedAns = Seq(
+      Row(
+        "This is a CDATA section containing <sample1> text.",
+        Row(
+          Seq(
+            "This is a CDATA section containing <sample2> text.\n" +
+              "        This is a CDATA section containing <sample3> text.",
+            "This is a CDATA section containing <sample4> text."
+          ),
+          Seq(1, 2)
+        )
+      )
+    )
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tag - equals to null value") {
+    // We don't consider options.nullValue during schema inference.
+    val xmlDF = readData(valueTagIsNullValue, Map("nullValue" -> "1"))
+    val expectedSchema = new StructType()
+      .add(valueTagName, LongType)
+    val expectedAns = Seq(Row(null))
+    // The nullValue option is applied during parsing.
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(xmlDF, expectedAns)
+  }
 }
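[Editor's note] Taken together, the options exercised by these tests can be combined when reading real files. A short, hypothetical usage sketch (the path and option values are illustrative; `_VALUE` is the default value-tag column name):

```scala
// Hypothetical reader configuration combining the options these tests cover.
val df = spark.read
  .option("rowTag", "ROW")                    // element treated as a row
  .option("ignoreSurroundingSpaces", "false") // keep spaces around character data
  .option("valueTag", "_VALUE")               // column name for mixed-content text
  .option("nullValue", "N/A")                 // applied at parse time, not inference
  .xml("/tmp/rows.xml")                       // illustrative path
```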
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org