This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a3c17b2e229 [SPARK-45562][SQL][FOLLOW-UP] XML: Make 'rowTag' option check case insensitive
a3c17b2e229 is described below

commit a3c17b2e22969de3d225fc9890023456592f6158
Author: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
AuthorDate: Thu Oct 19 13:23:04 2023 +0900

    [SPARK-45562][SQL][FOLLOW-UP] XML: Make 'rowTag' option check case insensitive
    
    ### What changes were proposed in this pull request?
    [PR 43389](https://github.com/apache/spark/pull/43389) made the `rowTag` option required for XML reads and writes. However, the presence check for that option was case-sensitive. This PR makes the check case-insensitive.
    
    ### Why are the changes needed?
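    Data source options are case-insensitive. The previous check in `XmlFileFormat` looked up `rowTag` in the raw, case-sensitive parameter map, so a spelling such as `rowtag` was rejected even though `XmlOptions` itself reads options through a `CaseInsensitiveMap`. The check now lives in `XmlOptions` and uses the same case-insensitive map.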
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43416 from sandip-db/xml-rowTagCaseInsensitive.
    
    Authored-by: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/sql/catalyst/xml/XmlOptions.scala  | 17 +++++++++++------
 .../sql/execution/datasources/xml/XmlFileFormat.scala   |  5 ++---
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index 0dedbec58e1..d2c7b435fe6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -34,7 +34,8 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 private[sql] class XmlOptions(
     @transient val parameters: CaseInsensitiveMap[String],
     defaultTimeZoneId: String,
-    defaultColumnNameOfCorruptRecord: String)
+    defaultColumnNameOfCorruptRecord: String,
+    rowTagRequired: Boolean)
   extends FileSourceOptions(parameters) with Logging {
 
   import XmlOptions._
@@ -42,11 +43,13 @@ private[sql] class XmlOptions(
   def this(
       parameters: Map[String, String] = Map.empty,
       defaultTimeZoneId: String = SQLConf.get.sessionLocalTimeZone,
-      defaultColumnNameOfCorruptRecord: String = 
SQLConf.get.columnNameOfCorruptRecord) = {
+      defaultColumnNameOfCorruptRecord: String = 
SQLConf.get.columnNameOfCorruptRecord,
+      rowTagRequired: Boolean = false) = {
     this(
       CaseInsensitiveMap(parameters),
       defaultTimeZoneId,
-      defaultColumnNameOfCorruptRecord)
+      defaultColumnNameOfCorruptRecord,
+      rowTagRequired)
   }
 
   private def getBool(paramName: String, default: Boolean = false): Boolean = {
@@ -63,7 +66,9 @@ private[sql] class XmlOptions(
   }
 
   val compressionCodec = 
parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
-  val rowTag = parameters.getOrElse(ROW_TAG, XmlOptions.DEFAULT_ROW_TAG).trim
+  val rowTagOpt = parameters.get(XmlOptions.ROW_TAG)
+  require(!rowTagRequired || rowTagOpt.isDefined, s"'${XmlOptions.ROW_TAG}' 
option is required.")
+  val rowTag = rowTagOpt.getOrElse(XmlOptions.DEFAULT_ROW_TAG).trim
   require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be an empty string.")
   require(!rowTag.startsWith("<") && !rowTag.endsWith(">"),
           s"'$ROW_TAG' should not include angle brackets")
@@ -223,8 +228,8 @@ private[sql] object XmlOptions extends DataSourceOptions {
   newOption(ENCODING, CHARSET)
 
   def apply(parameters: Map[String, String]): XmlOptions =
-    new XmlOptions(parameters, SQLConf.get.sessionLocalTimeZone)
+    new XmlOptions(parameters)
 
   def apply(): XmlOptions =
-    new XmlOptions(Map.empty, SQLConf.get.sessionLocalTimeZone)
+    new XmlOptions(Map.empty)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
index 4342711b00f..77619299278 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
@@ -42,11 +42,10 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister {
   def getXmlOptions(
       sparkSession: SparkSession,
       parameters: Map[String, String]): XmlOptions = {
-    val rowTagOpt = parameters.get(XmlOptions.ROW_TAG)
-    require(rowTagOpt.isDefined, s"'${XmlOptions.ROW_TAG}' option is 
required.")
     new XmlOptions(parameters,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
-      sparkSession.sessionState.conf.columnNameOfCorruptRecord)
+      sparkSession.sessionState.conf.columnNameOfCorruptRecord,
+      true)
   }
 
   override def isSplitable(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to