This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b7cdff929eb9 [SPARK-48100][SQL] Fix issues in skipping nested 
structure fields not selected in schema
b7cdff929eb9 is described below

commit b7cdff929eb92bce6661d73a2ccf8c7f6170b471
Author: Shujing Yang <shujing.y...@databricks.com>
AuthorDate: Wed May 15 13:59:11 2024 +0900

    [SPARK-48100][SQL] Fix issues in skipping nested structure fields not 
selected in schema
    
    ### What changes were proposed in this pull request?
    
    Previously, the XML parser couldn't skip nested structure data fields 
effectively when they were not selected in the schema. For instance, in the 
example below, `df.select("struct2").collect()` returns `Seq(null)` because 
`struct1` wasn't effectively skipped. This PR fixes this issue.
    
    ```
    <ROW>
      <struct1>
        <innerStruct><field1>1</field1></innerStruct>
      </struct1>
      <struct2>
        <field2>2</field2>
      </struct2>
    </ROW>
    ```
    
    We also added more tests regarding projection in this PR.
    ### Why are the changes needed?
    
    Fix a bug.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    UTs
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #46348 from shujingyang-db/fix-skip-children.
    
    Lead-authored-by: Shujing Yang <shujing.y...@databricks.com>
    Co-authored-by: Shujing Yang 
<135740748+shujingyang...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/catalyst/xml/StaxXmlParser.scala     |   3 +-
 .../sql/catalyst/xml/StaxXmlParserUtils.scala      |  38 ++---
 .../sql/execution/datasources/xml/XmlSuite.scala   | 157 +++++++++++++++++++++
 .../xml/parsers/StaxXmlParserUtilsSuite.scala      |  39 +++--
 4 files changed, 207 insertions(+), 30 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index ab671e56a21e..9a0528468842 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -397,8 +397,7 @@ class StaxXmlParser(
                     row(anyIndex) = values :+ newValue
                 }
               } else {
-                StaxXmlParserUtils.skipChildren(parser)
-                StaxXmlParserUtils.skipNextEndElement(parser, field, options)
+                StaxXmlParserUtils.skipChildren(parser, field, options)
               }
           }
         } catch {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
index a59ea6f460de..5d267143b06c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala
@@ -165,24 +165,27 @@ object StaxXmlParserUtils {
 
   /**
    * Skip the children of the current XML element.
+   * Before this function is called, the 'startElement' of the object has 
already been consumed.
+   * Upon completion, this function consumes the 'endElement' of the object,
+   * which effectively skips the entire object enclosed within these 
elements.
    */
-  def skipChildren(parser: XMLEventReader): Unit = {
-    var shouldStop = checkEndElement(parser)
+  def skipChildren(
+      parser: XMLEventReader,
+      expectedNextEndElementName: String,
+      options: XmlOptions): Unit = {
+    var shouldStop = false
     while (!shouldStop) {
       parser.nextEvent match {
-        case _: StartElement =>
-          val e = parser.peek
-          if (e.isCharacters && e.asCharacters.isWhiteSpace) {
-            // There can be a `Characters` event between `StartElement`s.
-            // So, we need to check further to decide if this is a data or just
-            // a whitespace between them.
-            parser.next
-          }
-          if (parser.peek.isStartElement) {
-            skipChildren(parser)
-          }
-        case _: EndElement =>
-          shouldStop = checkEndElement(parser)
+        case startElement: StartElement =>
+          val childField = 
StaxXmlParserUtils.getName(startElement.asStartElement.getName, options)
+          skipChildren(parser, childField, options)
+        case endElement: EndElement =>
+          val endElementName = getName(endElement.getName, options)
+          assert(
+            endElementName == expectedNextEndElementName,
+            s"Expected EndElement </$expectedNextEndElementName>, but found 
</$endElementName>"
+          )
+          shouldStop = true
         case _: XMLEvent => // do nothing
       }
     }
@@ -197,9 +200,10 @@ object StaxXmlParserUtils {
       case c: Characters if c.isWhiteSpace =>
         skipNextEndElement(parser, expectedNextEndElementName, options)
       case endElement: EndElement =>
+        val endElementName = getName(endElement.getName, options)
         assert(
-          getName(endElement.getName, options) == expectedNextEndElementName,
-          s"Expected EndElement </$expectedNextEndElementName>")
+          endElementName == expectedNextEndElementName,
+          s"Expected EndElement </$expectedNextEndElementName>, but found 
</$endElementName>")
       case _ => throw new IllegalStateException(
         s"Expected EndElement </$expectedNextEndElementName>")
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 51e8cfc7f103..1b39132c2fd7 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -77,6 +77,20 @@ class XmlSuite
   override def excluded: Seq[String] = Seq(
     s"Propagate Hadoop configs from $dataSourceFormat options to underlying 
file system")
 
+  private val baseOptions = Map("rowTag" -> "ROW")
+
+  private def readData(
+      xmlString: String,
+      schemaOpt: Option[StructType],
+      options: Map[String, String] = Map.empty): DataFrame = {
+    val ds = 
spark.createDataset(spark.sparkContext.parallelize(Seq(xmlString)))(Encoders.STRING)
+    if (schemaOpt.isDefined) {
+      spark.read.schema(schemaOpt.get).options(options).xml(ds)
+    } else {
+      spark.read.options(options).xml(ds)
+    }
+  }
+
   // Tests
 
   test("DSL test") {
@@ -3027,6 +3041,149 @@ class XmlSuite
       }
     }
   }
+
+  /////////////////////////////////////
+  // Projection, sorting, filtering  //
+  /////////////////////////////////////
+  test("select with string xml object") {
+    val xmlString =
+      s"""
+         |<ROW>
+         |    <name>John</name>
+         |    <metadata><id>3</id></metadata>
+         |</ROW>
+         |""".stripMargin
+    val schema = new StructType()
+      .add("name", StringType)
+      .add("metadata", StringType)
+    val df = readData(xmlString, Some(schema), baseOptions)
+    checkAnswer(df.select("name"), Seq(Row("John")))
+  }
+
+  test("select with duplicate field name in string xml object") {
+    val xmlString =
+      s"""
+         |<ROW>
+         |    <a><b>c</b></a>
+         |    <b>d</b>
+         |</ROW>
+         |""".stripMargin
+    val schema = new StructType()
+      .add("a", StringType)
+      .add("b", StringType)
+    val df = readData(xmlString, Some(schema), baseOptions)
+    val dfWithSchemaInference = readData(xmlString, None, baseOptions)
+    Seq(df, dfWithSchemaInference).foreach { df =>
+      checkAnswer(df.select("b"), Seq(Row("d")))
+    }
+  }
+
+  test("select nested struct objects") {
+    val xmlString =
+      s"""
+         |<ROW>
+         |    <struct>
+         |        <innerStruct>
+         |            <field1>1</field1>
+         |            <field2>2</field2>
+         |        </innerStruct>
+         |    </struct>
+         |</ROW>
+         |""".stripMargin
+    val schema = new StructType()
+      .add(
+        "struct",
+        new StructType()
+          .add("innerStruct", new StructType().add("field1", 
LongType).add("field2", LongType))
+      )
+    val df = readData(xmlString, Some(schema), baseOptions)
+    val dfWithSchemaInference = readData(xmlString, None, baseOptions)
+    Seq(df, dfWithSchemaInference).foreach { df =>
+      checkAnswer(df.select("struct"), Seq(Row(Row(Row(1, 2)))))
+      checkAnswer(df.select("struct.innerStruct"), Seq(Row(Row(1, 2))))
+    }
+  }
+
+  test("select a struct of lists") {
+    val xmlString =
+      s"""
+         |<ROW>
+         |    <struct>
+         |        <array><field>1</field></array>
+         |        <array><field>2</field></array>
+         |        <array><field>3</field></array>
+         |    </struct>
+         |</ROW>
+         |""".stripMargin
+    val schema = new StructType()
+      .add(
+        "struct",
+        new StructType()
+          .add("array", ArrayType(StructType(StructField("field", LongType) :: 
Nil))))
+
+    val df = readData(xmlString, Some(schema), baseOptions)
+    val dfWithSchemaInference = readData(xmlString, None, baseOptions)
+    Seq(df, dfWithSchemaInference).foreach { df =>
+      checkAnswer(df.select("struct"), Seq(Row(Row(Array(Row(1), Row(2), 
Row(3))))))
+      checkAnswer(df.select("struct.array"), Seq(Row(Array(Row(1), Row(2), 
Row(3)))))
+    }
+  }
+
+  test("select complex objects") {
+    val xmlString =
+      s"""
+         |<ROW>
+         |    1
+         |    <struct1>
+         |        value2
+         |        <struct2>
+         |            3
+         |            <array1>
+         |                value4
+         |                <struct3>
+         |                    5
+         |                    <array2>1<!--First comment--> <!--Second 
comment--></array2>
+         |                    value6
+         |                    <array2>2</array2>
+         |                    7
+         |                </struct3>
+         |                value8
+         |                <string>string</string>
+         |                9
+         |            </array1>
+         |            value10
+         |            <array1>
+         |                <struct3><!--First comment--> <!--Second comment-->
+         |                    <array2>3</array2>
+         |                    11
+         |                    <array2>4</array2><!--First comment--> 
<!--Second comment-->
+         |                </struct3>
+         |                <string>string</string>
+         |                value12
+         |            </array1>
+         |            13
+         |            <int>3</int>
+         |            value14
+         |        </struct2>
+         |        15
+         |    </struct1>
+         |     <!--First comment-->
+         |    value16
+         |     <!--Second comment-->
+         |</ROW>
+         |""".stripMargin
+    val df = readData(xmlString, None, baseOptions ++ Map("valueTag" -> 
"VALUE"))
+    checkAnswer(df.select("struct1.VALUE"), Seq(Row(Seq("value2", "15"))))
+    checkAnswer(df.select("struct1.struct2.array1"), Seq(Row(Seq(
+      Row(Seq("value4", "value8", "9"), "string", Row(Seq("5", "value6", "7"), 
Seq(1, 2))),
+      Row(Seq("value12"), "string", Row(Seq("11"), Seq(3, 4)))
+    ))))
+    checkAnswer(df.select("struct1.struct2.array1.struct3"), Seq(Row(Seq(
+      Row(Seq("5", "value6", "7"), Seq(1, 2)),
+      Row(Seq("11"), Seq(3, 4))
+    ))))
+    checkAnswer(df.select("struct1.struct2.array1.string"), 
Seq(Row(Seq("string", "string"))))
+  }
 }
 
 // Mock file system that checks the number of open files
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
index a4ac25b036c4..ad5b176f71f7 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlParserUtilsSuite.scala
@@ -73,17 +73,34 @@ final class StaxXmlParserUtilsSuite extends SparkFunSuite 
with BeforeAndAfterAll
     val input = <ROW><info>
       <name>Sam Mad Dog Smith</name><amount><small>1</small>
         <large>9</large></amount></info><abc>2</abc><test>2</test></ROW>
-    val parser = factory.createXMLEventReader(new StringReader(input.toString))
-    // We assume here it's reading the value within `id` field.
-    StaxXmlParserUtils.skipUntil(parser, XMLStreamConstants.CHARACTERS)
-    StaxXmlParserUtils.skipChildren(parser)
-    assert(parser.nextEvent().asEndElement().getName.getLocalPart === "info")
-    parser.next()
-    StaxXmlParserUtils.skipChildren(parser)
-    assert(parser.nextEvent().asEndElement().getName.getLocalPart === "abc")
-    parser.next()
-    StaxXmlParserUtils.skipChildren(parser)
-    assert(parser.nextEvent().asEndElement().getName.getLocalPart === "test")
+    val xmlOptions = new XmlOptions()
+    // skip the entire row
+    val parser1 = factory.createXMLEventReader(new 
StringReader(input.toString))
+    StaxXmlParserUtils.skipUntil(parser1, XMLStreamConstants.START_ELEMENT)
+    StaxXmlParserUtils.skipChildren(parser1, "ROW", xmlOptions)
+    assert(parser1.peek().getEventType === XMLStreamConstants.END_DOCUMENT)
+
+    // skip <name> and <amount> respectively
+    val parser2 = factory.createXMLEventReader(new 
StringReader(input.toString))
+    StaxXmlParserUtils.skipUntil(parser2, XMLStreamConstants.CHARACTERS)
+    // skip <name>
+    val elementName1 =
+      StaxXmlParserUtils.getName(parser2.nextEvent().asStartElement().getName, 
xmlOptions)
+    StaxXmlParserUtils.skipChildren(parser2, elementName1, xmlOptions)
+    assert(parser2.peek().getEventType === XMLStreamConstants.START_ELEMENT)
+    val elementName2 =
+      StaxXmlParserUtils.getName(parser2.peek().asStartElement().getName, 
xmlOptions)
+    assert(
+      StaxXmlParserUtils
+        .getName(parser2.peek().asStartElement().getName, xmlOptions) == 
elementName2
+    )
+    // skip <amount>
+    parser2.nextEvent()
+    StaxXmlParserUtils.skipChildren(parser2, elementName2, xmlOptions)
+    assert(parser2.peek().getEventType === XMLStreamConstants.END_ELEMENT)
+    assert(
+      StaxXmlParserUtils.getName(parser2.peek().asEndElement().getName, 
xmlOptions) == "info"
+    )
   }
 
   test("XML Input Factory disables DTD parsing") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to