This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b7cdff929eb9 [SPARK-48100][SQL] Fix issues in skipping nested structure fields not selected in schema b7cdff929eb9 is described below commit b7cdff929eb92bce6661d73a2ccf8c7f6170b471 Author: Shujing Yang <shujing.y...@databricks.com> AuthorDate: Wed May 15 13:59:11 2024 +0900 [SPARK-48100][SQL] Fix issues in skipping nested structure fields not selected in schema ### What changes were proposed in this pull request? Previously, the XML parser could not effectively skip nested structure data fields that were not selected in the schema. For instance, in the example below, `df.select("struct2").collect()` returns `Seq(null)` because `struct1` was not skipped correctly. This PR fixes this issue. ``` <ROW> <struct1> <innerStruct><field1>1</field1></innerStruct> </struct1> <struct2> <field2>2</field2> </struct2> </ROW> ``` We also added more tests for projection in this PR. ### Why are the changes needed? Fix a bug. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #46348 from shujingyang-db/fix-skip-children. 
Lead-authored-by: Shujing Yang <shujing.y...@databricks.com> Co-authored-by: Shujing Yang <135740748+shujingyang...@users.noreply.github.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../spark/sql/catalyst/xml/StaxXmlParser.scala | 3 +- .../sql/catalyst/xml/StaxXmlParserUtils.scala | 38 ++--- .../sql/execution/datasources/xml/XmlSuite.scala | 157 +++++++++++++++++++++ .../xml/parsers/StaxXmlParserUtilsSuite.scala | 39 +++-- 4 files changed, 207 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index ab671e56a21e..9a0528468842 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -397,8 +397,7 @@ class StaxXmlParser( row(anyIndex) = values :+ newValue } } else { - StaxXmlParserUtils.skipChildren(parser) - StaxXmlParserUtils.skipNextEndElement(parser, field, options) + StaxXmlParserUtils.skipChildren(parser, field, options) } } } catch { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala index a59ea6f460de..5d267143b06c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala @@ -165,24 +165,27 @@ object StaxXmlParserUtils { /** * Skip the children of the current XML element. + * Before this function is called, the 'startElement' of the object has already been consumed. + * Upon completion, this function consumes the 'endElement' of the object, + * which effectively skipping the entire object enclosed within these elements. 
*/ - def skipChildren(parser: XMLEventReader): Unit = { - var shouldStop = checkEndElement(parser) + def skipChildren( + parser: XMLEventReader, + expectedNextEndElementName: String, + options: XmlOptions): Unit = { + var shouldStop = false while (!shouldStop) { parser.nextEvent match { - case _: StartElement => - val e = parser.peek - if (e.isCharacters && e.asCharacters.isWhiteSpace) { - // There can be a `Characters` event between `StartElement`s. - // So, we need to check further to decide if this is a data or just - // a whitespace between them. - parser.next - } - if (parser.peek.isStartElement) { - skipChildren(parser) - } - case _: EndElement => - shouldStop = checkEndElement(parser) + case startElement: StartElement => + val childField = StaxXmlParserUtils.getName(startElement.asStartElement.getName, options) + skipChildren(parser, childField, options) + case endElement: EndElement => + val endElementName = getName(endElement.getName, options) + assert( + endElementName == expectedNextEndElementName, + s"Expected EndElement </$expectedNextEndElementName>, but found </$endElementName>" + ) + shouldStop = true case _: XMLEvent => // do nothing } } @@ -197,9 +200,10 @@ object StaxXmlParserUtils { case c: Characters if c.isWhiteSpace => skipNextEndElement(parser, expectedNextEndElementName, options) case endElement: EndElement => + val endElementName = getName(endElement.getName, options) assert( - getName(endElement.getName, options) == expectedNextEndElementName, - s"Expected EndElement </$expectedNextEndElementName>") + endElementName == expectedNextEndElementName, + s"Expected EndElement </$expectedNextEndElementName>, but found </$endElementName>") case _ => throw new IllegalStateException( s"Expected EndElement </$expectedNextEndElementName>") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 
51e8cfc7f103..1b39132c2fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -77,6 +77,20 @@ class XmlSuite override def excluded: Seq[String] = Seq( s"Propagate Hadoop configs from $dataSourceFormat options to underlying file system") + private val baseOptions = Map("rowTag" -> "ROW") + + private def readData( + xmlString: String, + schemaOpt: Option[StructType], + options: Map[String, String] = Map.empty): DataFrame = { + val ds = spark.createDataset(spark.sparkContext.parallelize(Seq(xmlString)))(Encoders.STRING) + if (schemaOpt.isDefined) { + spark.read.schema(schemaOpt.get).options(options).xml(ds) + } else { + spark.read.options(options).xml(ds) + } + } + // Tests test("DSL test") { @@ -3027,6 +3041,149 @@ class XmlSuite } } } + + ///////////////////////////////////// + // Projection, sorting, filtering // + ///////////////////////////////////// + test("select with string xml object") { + val xmlString = + s""" + |<ROW> + | <name>John</name> + | <metadata><id>3</id></metadata> + |</ROW> + |""".stripMargin + val schema = new StructType() + .add("name", StringType) + .add("metadata", StringType) + val df = readData(xmlString, Some(schema), baseOptions) + checkAnswer(df.select("name"), Seq(Row("John"))) + } + + test("select with duplicate field name in string xml object") { + val xmlString = + s""" + |<ROW> + | <a><b>c</b></a> + | <b>d</b> + |</ROW> + |""".stripMargin + val schema = new StructType() + .add("a", StringType) + .add("b", StringType) + val df = readData(xmlString, Some(schema), baseOptions) + val dfWithSchemaInference = readData(xmlString, None, baseOptions) + Seq(df, dfWithSchemaInference).foreach { df => + checkAnswer(df.select("b"), Seq(Row("d"))) + } + } + + test("select nested struct objects") { + val xmlString = + s""" + |<ROW> + | <struct> + | <innerStruct> + | <field1>1</field1> + | 
<field2>2</field2> + | </innerStruct> + | </struct> + |</ROW> + |""".stripMargin + val schema = new StructType() + .add( + "struct", + new StructType() + .add("innerStruct", new StructType().add("field1", LongType).add("field2", LongType)) + ) + val df = readData(xmlString, Some(schema), baseOptions) + val dfWithSchemaInference = readData(xmlString, None, baseOptions) + Seq(df, dfWithSchemaInference).foreach { df => + checkAnswer(df.select("struct"), Seq(Row(Row(Row(1, 2))))) + checkAnswer(df.select("struct.innerStruct"), Seq(Row(Row(1, 2)))) + } + } + + test("select a struct of lists") { + val xmlString = + s""" + |<ROW> + | <struct> + | <array><field>1</field></array> + | <array><field>2</field></array> + | <array><field>3</field></array> + | </struct> + |</ROW> + |""".stripMargin + val schema = new StructType() + .add( + "struct", + new StructType() + .add("array", ArrayType(StructType(StructField("field", LongType) :: Nil)))) + + val df = readData(xmlString, Some(schema), baseOptions) + val dfWithSchemaInference = readData(xmlString, None, baseOptions) + Seq(df, dfWithSchemaInference).foreach { df => + checkAnswer(df.select("struct"), Seq(Row(Row(Array(Row(1), Row(2), Row(3)))))) + checkAnswer(df.select("struct.array"), Seq(Row(Array(Row(1), Row(2), Row(3))))) + } + } + + test("select complex objects") { + val xmlString = + s""" + |<ROW> + | 1 + | <struct1> + | value2 + | <struct2> + | 3 + | <array1> + | value4 + | <struct3> + | 5 + | <array2>1<!--First comment--> <!--Second comment--></array2> + | value6 + | <array2>2</array2> + | 7 + | </struct3> + | value8 + | <string>string</string> + | 9 + | </array1> + | value10 + | <array1> + | <struct3><!--First comment--> <!--Second comment--> + | <array2>3</array2> + | 11 + | <array2>4</array2><!--First comment--> <!--Second comment--> + | </struct3> + | <string>string</string> + | value12 + | </array1> + | 13 + | <int>3</int> + | value14 + | </struct2> + | 15 + | </struct1> + | <!--First comment--> + | value16 + | 
<!--Second comment--> + |</ROW> + |""".stripMargin + val df = readData(xmlString, None, baseOptions ++ Map("valueTag" -> "VALUE")) + checkAnswer(df.select("struct1.VALUE"), Seq(Row(Seq("value2", "15")))) + checkAnswer(df.select("struct1.struct2.array1"), Seq(Row(Seq( + Row(Seq("value4", "value8", "9"), "string", Row(Seq("5", "value6", "7"), Seq(1, 2))), + Row(Seq("value12"), "string", Row(Seq("11"), Seq(3, 4))) + )))) + checkAnswer(df.select("struct1.struct2.array1.struct3"), Seq(Row(Seq( + Row(Seq("5", "value6", "7"), Seq(1, 2)), + Row(Seq("11"), Seq(3, 4)) + )))) + checkAnswer(df.select("struct1.struct2.array1.string"), Seq(Row(Seq("string", "string")))) + } } // Mock file system that checks the number of open files diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala index a4ac25b036c4..ad5b176f71f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala @@ -73,17 +73,34 @@ final class StaxXmlParserUtilsSuite extends SparkFunSuite with BeforeAndAfterAll val input = <ROW><info> <name>Sam Mad Dog Smith</name><amount><small>1</small> <large>9</large></amount></info><abc>2</abc><test>2</test></ROW> - val parser = factory.createXMLEventReader(new StringReader(input.toString)) - // We assume here it's reading the value within `id` field. 
- StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.CHARACTERS) - StaxXmlParserUtils.skipChildren(parser) - assert(parser.nextEvent().asEndElement().getName.getLocalPart === "info") - parser.next() - StaxXmlParserUtils.skipChildren(parser) - assert(parser.nextEvent().asEndElement().getName.getLocalPart === "abc") - parser.next() - StaxXmlParserUtils.skipChildren(parser) - assert(parser.nextEvent().asEndElement().getName.getLocalPart === "test") + val xmlOptions = new XmlOptions() + // skip the entire row + val parser1 = factory.createXMLEventReader(new StringReader(input.toString)) + StaxXmlParserUtils.skipUntil(parser1, XMLStreamConstants.START_ELEMENT) + StaxXmlParserUtils.skipChildren(parser1, "ROW", xmlOptions) + assert(parser1.peek().getEventType === XMLStreamConstants.END_DOCUMENT) + + // skip <name> and <amount> respectively + val parser2 = factory.createXMLEventReader(new StringReader(input.toString)) + StaxXmlParserUtils.skipUntil(parser2, XMLStreamConstants.CHARACTERS) + // skip <name> + val elementName1 = + StaxXmlParserUtils.getName(parser2.nextEvent().asStartElement().getName, xmlOptions) + StaxXmlParserUtils.skipChildren(parser2, elementName1, xmlOptions) + assert(parser2.peek().getEventType === XMLStreamConstants.START_ELEMENT) + val elementName2 = + StaxXmlParserUtils.getName(parser2.peek().asStartElement().getName, xmlOptions) + assert( + StaxXmlParserUtils + .getName(parser2.peek().asStartElement().getName, xmlOptions) == elementName2 + ) + // skip <amount> + parser2.nextEvent() + StaxXmlParserUtils.skipChildren(parser2, elementName2, xmlOptions) + assert(parser2.peek().getEventType === XMLStreamConstants.END_ELEMENT) + assert( + StaxXmlParserUtils.getName(parser2.peek().asEndElement().getName, xmlOptions) == "info" + ) } test("XML Input Factory disables DTD parsing") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: 
commits-h...@spark.apache.org