This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a3c17b2e229 [SPARK-45562][SQL][FOLLOW-UP] XML: Make 'rowTag' option check case insensitive a3c17b2e229 is described below commit a3c17b2e22969de3d225fc9890023456592f6158 Author: Sandip Agarwala <131817656+sandip...@users.noreply.github.com> AuthorDate: Thu Oct 19 13:23:04 2023 +0900 [SPARK-45562][SQL][FOLLOW-UP] XML: Make 'rowTag' option check case insensitive ### What changes were proposed in this pull request? [PR 43389](https://github.com/apache/spark/pull/43389) made the `rowTag` option required for XML read and write. However, the option check was done in a case-sensitive manner. This PR makes the check case-insensitive. ### Why are the changes needed? Options are case-insensitive. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43416 from sandip-db/xml-rowTagCaseInsensitive. Authored-by: Sandip Agarwala <131817656+sandip...@users.noreply.github.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../org/apache/spark/sql/catalyst/xml/XmlOptions.scala | 17 +++++++++++------ .../sql/execution/datasources/xml/XmlFileFormat.scala | 5 ++--- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala index 0dedbec58e1..d2c7b435fe6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala @@ -34,7 +34,8 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} private[sql] class XmlOptions( @transient val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, - defaultColumnNameOfCorruptRecord: String) + defaultColumnNameOfCorruptRecord: String, + 
rowTagRequired: Boolean) extends FileSourceOptions(parameters) with Logging { import XmlOptions._ @@ -42,11 +43,13 @@ private[sql] class XmlOptions( def this( parameters: Map[String, String] = Map.empty, defaultTimeZoneId: String = SQLConf.get.sessionLocalTimeZone, - defaultColumnNameOfCorruptRecord: String = SQLConf.get.columnNameOfCorruptRecord) = { + defaultColumnNameOfCorruptRecord: String = SQLConf.get.columnNameOfCorruptRecord, + rowTagRequired: Boolean = false) = { this( CaseInsensitiveMap(parameters), defaultTimeZoneId, - defaultColumnNameOfCorruptRecord) + defaultColumnNameOfCorruptRecord, + rowTagRequired) } private def getBool(paramName: String, default: Boolean = false): Boolean = { @@ -63,7 +66,9 @@ private[sql] class XmlOptions( } val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName) - val rowTag = parameters.getOrElse(ROW_TAG, XmlOptions.DEFAULT_ROW_TAG).trim + val rowTagOpt = parameters.get(XmlOptions.ROW_TAG) + require(!rowTagRequired || rowTagOpt.isDefined, s"'${XmlOptions.ROW_TAG}' option is required.") + val rowTag = rowTagOpt.getOrElse(XmlOptions.DEFAULT_ROW_TAG).trim require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be an empty string.") require(!rowTag.startsWith("<") && !rowTag.endsWith(">"), s"'$ROW_TAG' should not include angle brackets") @@ -223,8 +228,8 @@ private[sql] object XmlOptions extends DataSourceOptions { newOption(ENCODING, CHARSET) def apply(parameters: Map[String, String]): XmlOptions = - new XmlOptions(parameters, SQLConf.get.sessionLocalTimeZone) + new XmlOptions(parameters) def apply(): XmlOptions = - new XmlOptions(Map.empty, SQLConf.get.sessionLocalTimeZone) + new XmlOptions(Map.empty) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala index 4342711b00f..77619299278 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala @@ -42,11 +42,10 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister { def getXmlOptions( sparkSession: SparkSession, parameters: Map[String, String]): XmlOptions = { - val rowTagOpt = parameters.get(XmlOptions.ROW_TAG) - require(rowTagOpt.isDefined, s"'${XmlOptions.ROW_TAG}' option is required.") new XmlOptions(parameters, sparkSession.sessionState.conf.sessionLocalTimeZone, - sparkSession.sessionState.conf.columnNameOfCorruptRecord) + sparkSession.sessionState.conf.columnNameOfCorruptRecord, + true) } override def isSplitable( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org