This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 50d137d752e5 [SPARK-44751][SQL] XML: Refactor TypeCast and timestampFormatter
50d137d752e5 is described below

commit 50d137d752e5be61bcd7c28754eda7b300eb6e27
Author: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
AuthorDate: Tue Nov 7 09:33:02 2023 -0800

    [SPARK-44751][SQL] XML: Refactor TypeCast and timestampFormatter
    
    ### What changes were proposed in this pull request?
    - Move initialization of the TimestampFormatter from XmlOptions to StaxXmlParser, StaxXmlGenerator and XmlInferSchema.
    - Move the functions from TypeCast.scala to StaxXmlParser or XmlInferSchema.
    - Convert XmlInferSchema from an object to a class (see the sketch below).
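
    A minimal sketch of the resulting API change (the names come from this diff;
    the usage itself is only illustrative and is not part of the patch):

    ```scala
    // Before: a stateless object, with options threaded through every call
    //   val schema = XmlInferSchema.infer(xmlRdd, parsedOptions)

    // After: the options (and the formatters derived from them) live on the instance
    val inferSchema = new XmlInferSchema(parsedOptions)
    val schema: StructType = inferSchema.infer(xmlRdd)
    ```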
    
    ### Why are the changes needed?
    Some of the TimestampFormatter fields were not correctly initialized when they
    were accessed by StaxXmlParser on the executors, which caused timestamp parsing
    failures. Moving the initialization of the TimestampFormatter into StaxXmlParser
    fixes the issue.
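
    A reduced sketch of the pattern (the formatter arguments mirror this diff; the
    class name is hypothetical and only shows where the initialization now lives):

    ```scala
    import org.apache.spark.sql.catalyst.util.TimestampFormatter
    import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT

    class ParserSketch(options: XmlOptions) extends Serializable {
      // Built lazily from the serializable options, so whichever JVM deserializes
      // the parser constructs its own, fully initialized formatter.
      private lazy val timestampFormatter = TimestampFormatter(
        options.timestampFormatInRead,
        options.zoneId,
        options.locale,
        legacyFormat = FAST_DATE_FORMAT,
        isParsing = true)
    }
    ```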
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added new unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43697 from sandip-db/xml-typecast.
    
    Authored-by: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/catalyst/expressions/xmlExpressions.scala  |  11 +-
 .../spark/sql/catalyst/xml/StaxXmlGenerator.scala  |  29 ++-
 .../spark/sql/catalyst/xml/StaxXmlParser.scala     | 224 ++++++++++++++++---
 .../apache/spark/sql/catalyst/xml/TypeCast.scala   | 244 ---------------------
 .../spark/sql/catalyst/xml/XmlInferSchema.scala    | 158 ++++++++++---
 .../apache/spark/sql/catalyst/xml/XmlOptions.scala |  36 +--
 .../org/apache/spark/sql/DataFrameReader.scala     |   5 +-
 .../execution/datasources/xml/XmlDataSource.scala  |   6 +-
 .../spark/sql/streaming/DataStreamReader.scala     |   6 +-
 .../sql/execution/datasources/xml/XmlSuite.scala   | 182 ++++++++++++++-
 .../datasources/xml/util/TypeCastSuite.scala       | 234 --------------------
 11 files changed, 546 insertions(+), 589 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
index 047b669fc896..c581643460f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
@@ -188,6 +188,9 @@ case class SchemaOfXml(
   @transient
   private lazy val xmlFactory = xmlOptions.buildXmlFactory()
 
+  @transient
+  private lazy val xmlInferSchema = new XmlInferSchema(xmlOptions)
+
   @transient
   private lazy val xml = child.eval().asInstanceOf[UTF8String]
 
@@ -209,16 +212,16 @@ case class SchemaOfXml(
   }
 
   override def eval(v: InternalRow): Any = {
-    val dataType = XmlInferSchema.infer(xml.toString, xmlOptions).get match {
+    val dataType = xmlInferSchema.infer(xml.toString).get match {
       case st: StructType =>
-        XmlInferSchema.canonicalizeType(st).getOrElse(StructType(Nil))
+        xmlInferSchema.canonicalizeType(st).getOrElse(StructType(Nil))
       case at: ArrayType if at.elementType.isInstanceOf[StructType] =>
-        XmlInferSchema
+        xmlInferSchema
           .canonicalizeType(at.elementType)
           .map(ArrayType(_, containsNull = at.containsNull))
           .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull))
       case other: DataType =>
-        XmlInferSchema.canonicalizeType(other).getOrElse(StringType)
+        xmlInferSchema.canonicalizeType(other).getOrElse(StringType)
     }
 
     UTF8String.fromString(dataType.sql)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
index 4477cf50823c..ae3a64d865cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
@@ -26,7 +26,8 @@ import com.sun.xml.txw2.output.IndentingXMLStreamWriter
 import org.apache.hadoop.shaded.com.ctc.wstx.api.WstxOutputProperties
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, MapData, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -40,6 +41,26 @@ class StaxXmlGenerator(
     "'attributePrefix' option should not be empty string.")
   private val indentDisabled = options.indent == ""
 
+  private val timestampFormatter = TimestampFormatter(
+    options.timestampFormatInWrite,
+    options.zoneId,
+    options.locale,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = false)
+
+  private val timestampNTZFormatter = TimestampFormatter(
+    options.timestampNTZFormatInWrite,
+    options.zoneId,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = false,
+    forTimestampNTZ = true)
+
+  private val dateFormatter = DateFormatter(
+    options.dateFormatInWrite,
+    options.locale,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = false)
+
   private val gen = {
     val factory = XMLOutputFactory.newInstance()
     // to_xml disables structure validation to allow multiple root tags
@@ -149,11 +170,11 @@ class StaxXmlGenerator(
     case (StringType, v: UTF8String) => gen.writeCharacters(v.toString)
     case (StringType, v: String) => gen.writeCharacters(v)
     case (TimestampType, v: Timestamp) =>
-      gen.writeCharacters(options.timestampFormatterInWrite.format(v.toInstant()))
+      gen.writeCharacters(timestampFormatter.format(v.toInstant()))
     case (TimestampType, v: Long) =>
-      gen.writeCharacters(options.timestampFormatterInWrite.format(v))
+      gen.writeCharacters(timestampFormatter.format(v))
     case (DateType, v: Int) =>
-      gen.writeCharacters(options.dateFormatterInWrite.format(v))
+      gen.writeCharacters(dateFormatter.format(v))
     case (IntegerType, v: Int) => gen.writeCharacters(v.toString)
     case (ShortType, v: Short) => gen.writeCharacters(v.toString)
     case (FloatType, v: Float) => gen.writeCharacters(v.toString)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index dcb760aca9d2..77a0bd1dff17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.xml
 
 import java.io.{CharConversionException, InputStream, InputStreamReader, StringReader}
 import java.nio.charset.{Charset, MalformedInputException}
+import java.text.NumberFormat
+import java.util.Locale
 import javax.xml.stream.{XMLEventReader, XMLStreamException}
 import javax.xml.stream.events._
 import javax.xml.transform.stream.StreamSource
@@ -25,15 +27,17 @@ import javax.xml.validation.Schema
 
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
+import scala.util.Try
 import scala.util.control.NonFatal
 import scala.xml.SAXException
 
 import org.apache.spark.SparkUpgradeException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream
-import org.apache.spark.sql.catalyst.xml.TypeCast._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._
@@ -46,6 +50,29 @@ class StaxXmlParser(
 
   private val factory = options.buildXmlFactory()
 
+  private lazy val timestampFormatter = TimestampFormatter(
+    options.timestampFormatInRead,
+    options.zoneId,
+    options.locale,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = true)
+
+  private lazy val timestampNTZFormatter = TimestampFormatter(
+    options.timestampNTZFormatInRead,
+    options.zoneId,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = true,
+    forTimestampNTZ = true)
+
+  private lazy val dateFormatter = DateFormatter(
+    options.dateFormatInRead,
+    options.locale,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = true)
+
+  private val decimalParser = ExprUtils.getDecimalParser(options.locale)
+
+
   /**
    * Parses a single XML string and turns it into either one resulting row or no row (if the
    * the record is malformed).
@@ -108,7 +135,7 @@ class StaxXmlParser(
       val isRootAttributesOnly = schema.fields.forall { f =>
         f.name == options.valueTag || f.name.startsWith(options.attributePrefix)
       }
-      Some(convertObject(parser, schema, options, rootAttributes, isRootAttributesOnly))
+      Some(convertObject(parser, schema, rootAttributes, isRootAttributesOnly))
     } catch {
       case e: SparkUpgradeException => throw e
       case e@(_: RuntimeException | _: XMLStreamException | _: MalformedInputException
@@ -145,15 +172,14 @@ class StaxXmlParser(
   private[xml] def convertField(
       parser: XMLEventReader,
       dataType: DataType,
-      options: XmlOptions,
       attributes: Array[Attribute] = Array.empty): Any = {
 
     def convertComplicatedType(dt: DataType, attributes: Array[Attribute]): Any = dt match {
-      case st: StructType => convertObject(parser, st, options)
-      case MapType(StringType, vt, _) => convertMap(parser, vt, options, attributes)
-      case ArrayType(st, _) => convertField(parser, st, options)
+      case st: StructType => convertObject(parser, st)
+      case MapType(StringType, vt, _) => convertMap(parser, vt, attributes)
+      case ArrayType(st, _) => convertField(parser, st)
       case _: StringType =>
-        convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType, options)
+        convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType)
     }
 
     (parser.peek, dataType) match {
@@ -168,7 +194,7 @@ class StaxXmlParser(
       case (_: EndElement, _: DataType) => null
       case (c: Characters, ArrayType(st, _)) =>
         // For `ArrayType`, it needs to return the type of element. The values are merged later.
-        convertTo(c.getData, st, options)
+        convertTo(c.getData, st)
       case (c: Characters, st: StructType) =>
         // If a value tag is present, this can be an attribute-only element whose values is in that
         // value tag field. Or, it can be a mixed-type element with both some character elements
@@ -180,18 +206,18 @@ class StaxXmlParser(
           // If everything else is an attribute column, there's no complex structure.
           // Just return the value of the character element, or null if we don't have a value tag
           st.find(_.name == options.valueTag).map(
-            valueTag => convertTo(c.getData, valueTag.dataType, options)).orNull
+            valueTag => convertTo(c.getData, valueTag.dataType)).orNull
         } else {
           // Otherwise, ignore this character element, and continue parsing the following complex
           // structure
           parser.next
           parser.peek match {
             case _: EndElement => null // no struct here at all; done
-            case _ => convertObject(parser, st, options)
+            case _ => convertObject(parser, st)
           }
         }
       case (_: Characters, _: StringType) =>
-        convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType, options)
+        convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType)
       case (c: Characters, _: DataType) if c.isWhiteSpace =>
         // When `Characters` is found, we need to look further to decide
         // if this is really data or space between other elements.
@@ -201,11 +227,11 @@ class StaxXmlParser(
           case _: StartElement => convertComplicatedType(dataType, attributes)
           case _: EndElement if data.isEmpty => null
           case _: EndElement if options.treatEmptyValuesAsNulls => null
-          case _: EndElement => convertTo(data, dataType, options)
-          case _ => convertField(parser, dataType, options, attributes)
+          case _: EndElement => convertTo(data, dataType)
+          case _ => convertField(parser, dataType, attributes)
         }
       case (c: Characters, dt: DataType) =>
-        convertTo(c.getData, dt, options)
+        convertTo(c.getData, dt)
       case (e: XMLEvent, dt: DataType) =>
         throw new IllegalArgumentException(
           s"Failed to parse a value for data type $dt with event ${e.toString}")
@@ -218,12 +244,11 @@ class StaxXmlParser(
   private def convertMap(
       parser: XMLEventReader,
       valueType: DataType,
-      options: XmlOptions,
       attributes: Array[Attribute]): MapData = {
     val kvPairs = ArrayBuffer.empty[(UTF8String, Any)]
     attributes.foreach { attr =>
       kvPairs += (UTF8String.fromString(options.attributePrefix + attr.getName.getLocalPart)
-        -> convertTo(attr.getValue, valueType, options))
+        -> convertTo(attr.getValue, valueType))
     }
     var shouldStop = false
     while (!shouldStop) {
@@ -231,7 +256,7 @@ class StaxXmlParser(
         case e: StartElement =>
           kvPairs +=
             (UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, options)) ->
-             convertField(parser, valueType, options))
+             convertField(parser, valueType))
         case _: EndElement =>
           shouldStop = StaxXmlParserUtils.checkEndElement(parser)
         case _ => // do nothing
@@ -245,14 +270,13 @@ class StaxXmlParser(
    */
   private def convertAttributes(
       attributes: Array[Attribute],
-      schema: StructType,
-      options: XmlOptions): Map[String, Any] = {
+      schema: StructType): Map[String, Any] = {
     val convertedValuesMap = collection.mutable.Map.empty[String, Any]
     val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
     valuesMap.foreach { case (f, v) =>
       val nameToIndex = schema.map(_.name).zipWithIndex.toMap
       nameToIndex.get(f).foreach { i =>
-        convertedValuesMap(f) = convertTo(v, schema(i).dataType, options)
+        convertedValuesMap(f) = convertTo(v, schema(i).dataType)
       }
     }
     convertedValuesMap.toMap
@@ -266,16 +290,15 @@ class StaxXmlParser(
   private def convertObjectWithAttributes(
       parser: XMLEventReader,
       schema: StructType,
-      options: XmlOptions,
       attributes: Array[Attribute] = Array.empty): InternalRow = {
     // TODO: This method might have to be removed. Some logics duplicate `convertObject()`
     val row = new Array[Any](schema.length)
 
     // Read attributes first.
-    val attributesMap = convertAttributes(attributes, schema, options)
+    val attributesMap = convertAttributes(attributes, schema)
 
     // Then, we read elements here.
-    val fieldsMap = convertField(parser, schema, options) match {
+    val fieldsMap = convertField(parser, schema) match {
       case internalRow: InternalRow =>
         Map(schema.map(_.name).zip(internalRow.toSeq(schema)): _*)
       case v if schema.fieldNames.contains(options.valueTag) =>
@@ -309,13 +332,12 @@ class StaxXmlParser(
   private def convertObject(
       parser: XMLEventReader,
       schema: StructType,
-      options: XmlOptions,
       rootAttributes: Array[Attribute] = Array.empty,
       isRootAttributesOnly: Boolean = false): InternalRow = {
     val row = new Array[Any](schema.length)
     val nameToIndex = schema.map(_.name).zipWithIndex.toMap
     // If there are attributes, then we process them first.
-    convertAttributes(rootAttributes, schema, options).toSeq.foreach { case (f, v) =>
+    convertAttributes(rootAttributes, schema).toSeq.foreach { case (f, v) =>
       nameToIndex.get(f).foreach { row(_) = v }
     }
 
@@ -334,7 +356,7 @@ class StaxXmlParser(
           nameToIndex.get(field) match {
             case Some(index) => schema(index).dataType match {
               case st: StructType =>
-                row(index) = convertObjectWithAttributes(parser, st, options, attributes)
+                row(index) = convertObjectWithAttributes(parser, st, attributes)
 
               case ArrayType(dt: DataType, _) =>
                 val values = Option(row(index))
@@ -342,21 +364,21 @@ class StaxXmlParser(
                   .getOrElse(ArrayBuffer.empty[Any])
                 val newValue = dt match {
                   case st: StructType =>
-                    convertObjectWithAttributes(parser, st, options, attributes)
+                    convertObjectWithAttributes(parser, st, attributes)
                   case dt: DataType =>
-                    convertField(parser, dt, options)
+                    convertField(parser, dt)
                 }
                 row(index) = values :+ newValue
 
               case dt: DataType =>
-                row(index) = convertField(parser, dt, options, attributes)
+                row(index) = convertField(parser, dt, attributes)
             }
 
             case None =>
               if (hasWildcard) {
                 // Special case: there's an 'any' wildcard element that matches anything else
                 // as a string (or array of strings, to parse multiple ones)
-                val newValue = convertField(parser, StringType, options)
+                val newValue = convertField(parser, StringType)
                 val anyIndex = schema.fieldIndex(wildcardColName)
                 schema(wildcardColName).dataType match {
                   case StringType =>
@@ -380,7 +402,7 @@ class StaxXmlParser(
         case c: Characters if !c.isWhiteSpace && isRootAttributesOnly =>
           nameToIndex.get(options.valueTag) match {
             case Some(index) =>
-              row(index) = convertTo(c.getData, schema(index).dataType, options)
+              row(index) = convertTo(c.getData, schema(index).dataType)
             case None => // do nothing
           }
 
@@ -410,6 +432,144 @@ class StaxXmlParser(
         badRecordException.get)
     }
   }
+
+  /**
+   * Casts given string datum to specified type.
+   *
+   * For string types, this is simply the datum.
+   * For other nullable types, returns null if it is null or equals to the value specified
+   * in `nullValue` option.
+   *
+   * @param datum    string value
+   * @param castType SparkSQL type
+   */
+  private def castTo(
+      datum: String,
+      castType: DataType): Any = {
+    if ((datum == options.nullValue) ||
+      (options.treatEmptyValuesAsNulls && datum == "")) {
+      null
+    } else {
+      castType match {
+        case _: ByteType => datum.toByte
+        case _: ShortType => datum.toShort
+        case _: IntegerType => datum.toInt
+        case _: LongType => datum.toLong
+        case _: FloatType => Try(datum.toFloat)
+          .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
+        case _: DoubleType => Try(datum.toDouble)
+          .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
+        case _: BooleanType => parseXmlBoolean(datum)
+        case dt: DecimalType =>
+          Decimal(decimalParser(datum), dt.precision, dt.scale)
+        case _: TimestampType => parseXmlTimestamp(datum, options)
+        case _: DateType => parseXmlDate(datum, options)
+        case _: StringType => UTF8String.fromString(datum)
+        case _ => throw new IllegalArgumentException(s"Unsupported type: ${castType.typeName}")
+      }
+    }
+  }
+
+  private def parseXmlBoolean(s: String): Boolean = {
+    s.toLowerCase(Locale.ROOT) match {
+      case "true" | "1" => true
+      case "false" | "0" => false
+      case _ => throw new IllegalArgumentException(s"For input string: $s")
+    }
+  }
+
+  private def parseXmlDate(value: String, options: XmlOptions): Int = {
+    dateFormatter.parse(value)
+  }
+
+  private def parseXmlTimestamp(value: String, options: XmlOptions): Long = {
+    timestampFormatter.parse(value)
+  }
+
+  // TODO: This function unnecessarily does type dispatch. Should merge it with `castTo`.
+  private def convertTo(
+      datum: String,
+      dataType: DataType): Any = {
+    val value = if (datum != null && options.ignoreSurroundingSpaces) {
+      datum.trim()
+    } else {
+      datum
+    }
+    if ((value == options.nullValue) ||
+      (options.treatEmptyValuesAsNulls && value == "")) {
+      null
+    } else {
+      dataType match {
+        case NullType => castTo(value, StringType)
+        case LongType => signSafeToLong(value)
+        case DoubleType => signSafeToDouble(value)
+        case BooleanType => castTo(value, BooleanType)
+        case StringType => castTo(value, StringType)
+        case DateType => castTo(value, DateType)
+        case TimestampType => castTo(value, TimestampType)
+        case FloatType => signSafeToFloat(value)
+        case ByteType => castTo(value, ByteType)
+        case ShortType => castTo(value, ShortType)
+        case IntegerType => signSafeToInt(value)
+        case dt: DecimalType => castTo(value, dt)
+        case _ => throw new IllegalArgumentException(
+          s"Failed to parse a value for data type $dataType.")
+      }
+    }
+  }
+
+
+  private def signSafeToLong(value: String): Long = {
+    if (value.startsWith("+")) {
+      val data = value.substring(1)
+      castTo(data, LongType).asInstanceOf[Long]
+    } else if (value.startsWith("-")) {
+      val data = value.substring(1)
+      -castTo(data, LongType).asInstanceOf[Long]
+    } else {
+      val data = value
+      castTo(data, LongType).asInstanceOf[Long]
+    }
+  }
+
+  private def signSafeToDouble(value: String): Double = {
+    if (value.startsWith("+")) {
+      val data = value.substring(1)
+      castTo(data, DoubleType).asInstanceOf[Double]
+    } else if (value.startsWith("-")) {
+      val data = value.substring(1)
+      -castTo(data, DoubleType).asInstanceOf[Double]
+    } else {
+      val data = value
+      castTo(data, DoubleType).asInstanceOf[Double]
+    }
+  }
+
+  private def signSafeToInt(value: String): Int = {
+    if (value.startsWith("+")) {
+      val data = value.substring(1)
+      castTo(data, IntegerType).asInstanceOf[Int]
+    } else if (value.startsWith("-")) {
+      val data = value.substring(1)
+      -castTo(data, IntegerType).asInstanceOf[Int]
+    } else {
+      val data = value
+      castTo(data, IntegerType).asInstanceOf[Int]
+    }
+  }
+
+  private def signSafeToFloat(value: String): Float = {
+    if (value.startsWith("+")) {
+      val data = value.substring(1)
+      castTo(data, FloatType).asInstanceOf[Float]
+    } else if (value.startsWith("-")) {
+      val data = value.substring(1)
+      -castTo(data, FloatType).asInstanceOf[Float]
+    } else {
+      val data = value
+      castTo(data, FloatType).asInstanceOf[Float]
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/TypeCast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/TypeCast.scala
deleted file mode 100644
index 3315196ffc76..000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/TypeCast.scala
+++ /dev/null
@@ -1,244 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.catalyst.xml
-
-import java.math.BigDecimal
-import java.text.NumberFormat
-import java.util.Locale
-
-import scala.util.Try
-import scala.util.control.Exception._
-
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-/**
- * Utility functions for type casting
- */
-private[sql] object TypeCast {
-
-  /**
-   * Casts given string datum to specified type.
-   * Currently we do not support complex types (ArrayType, MapType, StructType).
-   *
-   * For string types, this is simply the datum. For other types.
-   * For other nullable types, this is null if the string datum is empty.
-   *
-   * @param datum string value
-   * @param castType SparkSQL type
-   */
-  private[sql] def castTo(
-      datum: String,
-      castType: DataType,
-      options: XmlOptions): Any = {
-    if ((datum == options.nullValue) ||
-        (options.treatEmptyValuesAsNulls && datum == "")) {
-      null
-    } else {
-      castType match {
-        case _: ByteType => datum.toByte
-        case _: ShortType => datum.toShort
-        case _: IntegerType => datum.toInt
-        case _: LongType => datum.toLong
-        case _: FloatType => Try(datum.toFloat)
-          .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
-        case _: DoubleType => Try(datum.toDouble)
-          .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
-        case _: BooleanType => parseXmlBoolean(datum)
-        case dt: DecimalType =>
-          Decimal(new BigDecimal(datum.replaceAll(",", "")), dt.precision, dt.scale)
-        case _: TimestampType => parseXmlTimestamp(datum, options)
-        case _: DateType => parseXmlDate(datum, options)
-        case _: StringType => UTF8String.fromString(datum)
-        case _ => throw new IllegalArgumentException(s"Unsupported type: ${castType.typeName}")
-      }
-    }
-  }
-
-  private def parseXmlBoolean(s: String): Boolean = {
-    s.toLowerCase(Locale.ROOT) match {
-      case "true" | "1" => true
-      case "false" | "0" => false
-      case _ => throw new IllegalArgumentException(s"For input string: $s")
-    }
-  }
-
-  private def parseXmlDate(value: String, options: XmlOptions): Int = {
-    options.dateFormatter.parse(value)
-  }
-
-  private def parseXmlTimestamp(value: String, options: XmlOptions): Long = {
-    options.timestampFormatter.parse(value)
-  }
-
-  // TODO: This function unnecessarily does type dispatch. Should merge it with `castTo`.
-  private[sql] def convertTo(
-      datum: String,
-      dataType: DataType,
-      options: XmlOptions): Any = {
-    val value = if (datum != null && options.ignoreSurroundingSpaces) {
-      datum.trim()
-    } else {
-      datum
-    }
-    if ((value == options.nullValue) ||
-      (options.treatEmptyValuesAsNulls && value == "")) {
-      null
-    } else {
-      dataType match {
-        case NullType => castTo(value, StringType, options)
-        case LongType => signSafeToLong(value, options)
-        case DoubleType => signSafeToDouble(value, options)
-        case BooleanType => castTo(value, BooleanType, options)
-        case StringType => castTo(value, StringType, options)
-        case DateType => castTo(value, DateType, options)
-        case TimestampType => castTo(value, TimestampType, options)
-        case FloatType => signSafeToFloat(value, options)
-        case ByteType => castTo(value, ByteType, options)
-        case ShortType => castTo(value, ShortType, options)
-        case IntegerType => signSafeToInt(value, options)
-        case dt: DecimalType => castTo(value, dt, options)
-        case _ => throw new IllegalArgumentException(
-          s"Failed to parse a value for data type $dataType.")
-      }
-    }
-  }
-
-  /**
-   * Helper method that checks and cast string representation of a numeric types.
-   */
-  private[sql] def isBoolean(value: String): Boolean = {
-    value.toLowerCase(Locale.ROOT) match {
-      case "true" | "false" => true
-      case _ => false
-    }
-  }
-
-  private[sql] def isDouble(value: String): Boolean = {
-    val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) {
-      value.substring(1)
-    } else {
-      value
-    }
-    // A little shortcut to avoid trying many formatters in the common case that
-    // the input isn't a double. All built-in formats will start with a digit or period.
-    if (signSafeValue.isEmpty ||
-      !(Character.isDigit(signSafeValue.head) || signSafeValue.head == '.')) {
-      return false
-    }
-    // Rule out strings ending in D or F, as they will parse as double but should be disallowed
-    if (value.nonEmpty && (value.last match {
-          case 'd' | 'D' | 'f' | 'F' => true
-          case _ => false
-        })) {
-      return false
-    }
-    (allCatch opt signSafeValue.toDouble).isDefined
-  }
-
-  private[sql] def isInteger(value: String): Boolean = {
-    val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) {
-      value.substring(1)
-    } else {
-      value
-    }
-    // A little shortcut to avoid trying many formatters in the common case that
-    // the input isn't a number. All built-in formats will start with a digit.
-    if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) {
-      return false
-    }
-    (allCatch opt signSafeValue.toInt).isDefined
-  }
-
-  private[sql] def isLong(value: String): Boolean = {
-    val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) {
-      value.substring(1)
-    } else {
-      value
-    }
-    // A little shortcut to avoid trying many formatters in the common case that
-    // the input isn't a number. All built-in formats will start with a digit.
-    if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) {
-      return false
-    }
-    (allCatch opt signSafeValue.toLong).isDefined
-  }
-
-  private[sql] def isTimestamp(value: String, options: XmlOptions): Boolean = {
-    try {
-      options.timestampFormatter.parseOptional(value).isDefined
-    } catch {
-      case _: IllegalArgumentException => false
-    }
-  }
-
-  private[sql] def isDate(value: String, options: XmlOptions): Boolean = {
-    (allCatch opt options.dateFormatter.parse(value)).isDefined
-  }
-
-  private[sql] def signSafeToLong(value: String, options: XmlOptions): Long = {
-    if (value.startsWith("+")) {
-      val data = value.substring(1)
-      TypeCast.castTo(data, LongType, options).asInstanceOf[Long]
-    } else if (value.startsWith("-")) {
-      val data = value.substring(1)
-      -TypeCast.castTo(data, LongType, options).asInstanceOf[Long]
-    } else {
-      val data = value
-      TypeCast.castTo(data, LongType, options).asInstanceOf[Long]
-    }
-  }
-
-  private[sql] def signSafeToDouble(value: String, options: XmlOptions): Double = {
-    if (value.startsWith("+")) {
-      val data = value.substring(1)
-      TypeCast.castTo(data, DoubleType, options).asInstanceOf[Double]
-    } else if (value.startsWith("-")) {
-      val data = value.substring(1)
-     -TypeCast.castTo(data, DoubleType, options).asInstanceOf[Double]
-    } else {
-      val data = value
-      TypeCast.castTo(data, DoubleType, options).asInstanceOf[Double]
-    }
-  }
-
-  private[sql] def signSafeToInt(value: String, options: XmlOptions): Int = {
-    if (value.startsWith("+")) {
-      val data = value.substring(1)
-      TypeCast.castTo(data, IntegerType, options).asInstanceOf[Int]
-    } else if (value.startsWith("-")) {
-      val data = value.substring(1)
-      -TypeCast.castTo(data, IntegerType, options).asInstanceOf[Int]
-    } else {
-      val data = value
-      TypeCast.castTo(data, IntegerType, options).asInstanceOf[Int]
-    }
-  }
-
-  private[sql] def signSafeToFloat(value: String, options: XmlOptions): Float = {
-    if (value.startsWith("+")) {
-      val data = value.substring(1)
-      TypeCast.castTo(data, FloatType, options).asInstanceOf[Float]
-    } else if (value.startsWith("-")) {
-      val data = value.substring(1)
-      -TypeCast.castTo(data, FloatType, options).asInstanceOf[Float]
-    } else {
-      val data = value
-      TypeCast.castTo(data, FloatType, options).asInstanceOf[Float]
-    }
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
index 8bddb8f5bd99..777dd69fd7fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.catalyst.xml
 
 import java.io.StringReader
+import java.util.Locale
 import javax.xml.stream.XMLEventReader
 import javax.xml.stream.events._
 import javax.xml.transform.stream.StreamSource
@@ -25,14 +26,39 @@ import javax.xml.validation.Schema
 import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
+import scala.util.control.Exception._
 import scala.util.control.NonFatal
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.util.PermissiveMode
-import org.apache.spark.sql.catalyst.xml.TypeCast._
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.util.{DateFormatter, PermissiveMode, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.types._
 
-private[sql] object XmlInferSchema {
+private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with Logging {
+
+  private val decimalParser = ExprUtils.getDecimalParser(options.locale)
+
+  private val timestampFormatter = TimestampFormatter(
+    options.timestampFormatInRead,
+    options.zoneId,
+    options.locale,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = true)
+
+  private val timestampNTZFormatter = TimestampFormatter(
+    options.timestampNTZFormatInRead,
+    options.zoneId,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = true,
+    forTimestampNTZ = true)
+
+  private lazy val dateFormatter = DateFormatter(
+    options.dateFormatInRead,
+    options.locale,
+    legacyFormat = FAST_DATE_FORMAT,
+    isParsing = true)
 
   /**
    * Copied from internal Spark api
@@ -66,7 +92,7 @@ private[sql] object XmlInferSchema {
    *   2. Merge types by choosing the lowest type necessary to cover equal keys
    *   3. Replace any remaining null fields with string, the top type
    */
-  def infer(xml: RDD[String], options: XmlOptions): StructType = {
+  def infer(xml: RDD[String]): StructType = {
     val schemaData = if (options.samplingRatio < 1.0) {
       xml.sample(withReplacement = false, options.samplingRatio, 1)
     } else {
@@ -77,9 +103,9 @@ private[sql] object XmlInferSchema {
       val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)
 
       iter.flatMap { xml =>
-        infer(xml, options, xsdSchema)
+        infer(xml, xsdSchema)
       }
-    }.fold(StructType(Seq()))(compatibleType(options))
+    }.fold(StructType(Seq()))(compatibleType)
 
     canonicalizeType(rootType) match {
       case Some(st: StructType) => st
@@ -90,7 +116,6 @@ private[sql] object XmlInferSchema {
   }
 
   def infer(xml: String,
-      options: XmlOptions,
       xsdSchema: Option[Schema] = None): Option[DataType] = {
     try {
       val xsd = xsdSchema.orElse(Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema))
@@ -99,7 +124,7 @@ private[sql] object XmlInferSchema {
       }
       val parser = StaxXmlParserUtils.filteredReader(xml)
       val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
-      Some(inferObject(parser, options, rootAttributes))
+      Some(inferObject(parser, rootAttributes))
     } catch {
       case NonFatal(_) if options.parseMode == PermissiveMode =>
         Some(StructType(Seq(StructField(options.columnNameOfCorruptRecord, StringType))))
@@ -108,7 +133,7 @@ private[sql] object XmlInferSchema {
     }
   }
 
-  private def inferFrom(datum: String, options: XmlOptions): DataType = {
+  private def inferFrom(datum: String): DataType = {
     val value = if (datum != null && options.ignoreSurroundingSpaces) {
       datum.trim()
     } else {
@@ -123,8 +148,8 @@ private[sql] object XmlInferSchema {
         case v if isInteger(v) => IntegerType
         case v if isDouble(v) => DoubleType
         case v if isBoolean(v) => BooleanType
-        case v if isDate(v, options) => DateType
-        case v if isTimestamp(v, options) => TimestampType
+        case v if isDate(v) => DateType
+        case v if isTimestamp(v) => TimestampType
         case _ => StringType
       }
     } else {
@@ -133,32 +158,32 @@ private[sql] object XmlInferSchema {
   }
 
   @tailrec
-  private def inferField(parser: XMLEventReader, options: XmlOptions): DataType = {
+  private def inferField(parser: XMLEventReader): DataType = {
     parser.peek match {
       case _: EndElement => NullType
-      case _: StartElement => inferObject(parser, options)
+      case _: StartElement => inferObject(parser)
       case c: Characters if c.isWhiteSpace =>
         // When `Characters` is found, we need to look further to decide
         // if this is really data or space between other elements.
         val data = c.getData
         parser.nextEvent()
         parser.peek match {
-          case _: StartElement => inferObject(parser, options)
+          case _: StartElement => inferObject(parser)
           case _: EndElement if data.isEmpty => NullType
           case _: EndElement if options.treatEmptyValuesAsNulls => NullType
           case _: EndElement => StringType
-          case _ => inferField(parser, options)
+          case _ => inferField(parser)
         }
       case c: Characters if !c.isWhiteSpace =>
         // This could be the characters of a character-only element, or could have mixed
         // characters and other complex structure
-        val characterType = inferFrom(c.getData, options)
+        val characterType = inferFrom(c.getData)
         parser.nextEvent()
         parser.peek match {
           case _: StartElement =>
             // Some more elements follow; so ignore the characters.
             // Use the schema of the rest
-            inferObject(parser, options).asInstanceOf[StructType]
+            inferObject(parser).asInstanceOf[StructType]
           case _ =>
             // That's all, just the character-only body; use that as the type
             characterType
@@ -173,7 +198,6 @@ private[sql] object XmlInferSchema {
    */
   private def inferObject(
       parser: XMLEventReader,
-      options: XmlOptions,
       rootAttributes: Array[Attribute] = Array.empty): DataType = {
     val builder = ArrayBuffer[StructField]()
     val nameToDataType = collection.mutable.Map.empty[String, ArrayBuffer[DataType]]
@@ -182,7 +206,7 @@ private[sql] object XmlInferSchema {
       StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
     rootValuesMap.foreach {
       case (f, v) =>
-        nameToDataType += (f -> ArrayBuffer(inferFrom(v, options)))
+        nameToDataType += (f -> ArrayBuffer(inferFrom(v)))
     }
     var shouldStop = false
     while (!shouldStop) {
@@ -190,14 +214,14 @@ private[sql] object XmlInferSchema {
         case e: StartElement =>
           val attributes = e.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray
           val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
-          val inferredType = inferField(parser, options) match {
+          val inferredType = inferField(parser) match {
             case st: StructType if valuesMap.nonEmpty =>
               // Merge attributes to the field
               val nestedBuilder = ArrayBuffer[StructField]()
               nestedBuilder ++= st.fields
               valuesMap.foreach {
                 case (f, v) =>
-                  nestedBuilder += StructField(f, inferFrom(v, options), nullable = true)
+                  nestedBuilder += StructField(f, inferFrom(v), nullable = true)
               }
               StructType(nestedBuilder.sortBy(_.name).toArray)
 
@@ -207,7 +231,7 @@ private[sql] object XmlInferSchema {
               nestedBuilder += StructField(options.valueTag, dt, nullable = true)
               valuesMap.foreach {
                 case (f, v) =>
-                  nestedBuilder += StructField(f, inferFrom(v, options), nullable = true)
+                  nestedBuilder += StructField(f, inferFrom(v), nullable = true)
               }
               StructType(nestedBuilder.sortBy(_.name).toArray)
 
@@ -221,7 +245,7 @@ private[sql] object XmlInferSchema {
 
         case c: Characters if !c.isWhiteSpace =>
           // This can be an attribute-only object
-          val valueTagType = inferFrom(c.getData, options)
+          val valueTagType = inferFrom(c.getData)
           nameToDataType += options.valueTag -> ArrayBuffer(valueTagType)
 
         case _: EndElement =>
@@ -245,7 +269,7 @@ private[sql] object XmlInferSchema {
     // This can be inferred as ArrayType.
     nameToDataType.foreach {
       case (field, dataTypes) if dataTypes.length > 1 =>
-        val elementType = dataTypes.reduceLeft(XmlInferSchema.compatibleType(options))
+        val elementType = dataTypes.reduceLeft(compatibleType)
         builder += StructField(field, ArrayType(elementType), nullable = true)
       case (field, dataTypes) =>
         builder += StructField(field, dataTypes.head, nullable = true)
@@ -255,6 +279,78 @@ private[sql] object XmlInferSchema {
     StructType(builder.sortBy(_.name).toArray)
   }
 
+  /**
+   * Helper method that checks and cast string representation of a numeric types.
+   */
+  private def isBoolean(value: String): Boolean = {
+    value.toLowerCase(Locale.ROOT) match {
+      case "true" | "false" => true
+      case _ => false
+    }
+  }
+
+  private def isDouble(value: String): Boolean = {
+    val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) {
+      value.substring(1)
+    } else {
+      value
+    }
+    // A little shortcut to avoid trying many formatters in the common case that
+    // the input isn't a double. All built-in formats will start with a digit or period.
+    if (signSafeValue.isEmpty ||
+      !(Character.isDigit(signSafeValue.head) || signSafeValue.head == '.')) {
+      return false
+    }
+    // Rule out strings ending in D or F, as they will parse as double but should be disallowed
+    if (value.nonEmpty && (value.last match {
+      case 'd' | 'D' | 'f' | 'F' => true
+      case _ => false
+    })) {
+      return false
+    }
+    (allCatch opt signSafeValue.toDouble).isDefined
+  }
+
+  private def isInteger(value: String): Boolean = {
+    val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) {
+      value.substring(1)
+    } else {
+      value
+    }
+    // A little shortcut to avoid trying many formatters in the common case that
+    // the input isn't a number. All built-in formats will start with a digit.
+    if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) {
+      return false
+    }
+    (allCatch opt signSafeValue.toInt).isDefined
+  }
+
+  private def isLong(value: String): Boolean = {
+    val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) {
+      value.substring(1)
+    } else {
+      value
+    }
+    // A little shortcut to avoid trying many formatters in the common case that
+    // the input isn't a number. All built-in formats will start with a digit.
+    if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) {
+      return false
+    }
+    (allCatch opt signSafeValue.toLong).isDefined
+  }
+
+  private def isTimestamp(value: String): Boolean = {
+    try {
+      timestampFormatter.parseOptional(value).isDefined
+    } catch {
+      case _: IllegalArgumentException => false
+    }
+  }
+
+  private def isDate(value: String): Boolean = {
+    (allCatch opt dateFormatter.parse(value)).isDefined
+  }
+
   /**
    * Convert NullType to StringType and remove StructTypes with no fields
    */
@@ -288,7 +384,7 @@ private[sql] object XmlInferSchema {
   /**
    * Returns the most general data type for two given data types.
    */
-  private[xml] def compatibleType(options: XmlOptions)(t1: DataType, t2: DataType): DataType = {
+  def compatibleType(t1: DataType, t2: DataType): DataType = {
     // TODO: Optimise this logic.
     findTightestCommonTypeOfTwo(t1, t2).getOrElse {
       // t1 or t2 is a StructType, ArrayType, or an unexpected type.
@@ -312,22 +408,22 @@ private[sql] object XmlInferSchema {
         case (StructType(fields1), StructType(fields2)) =>
           val newFields = (fields1 ++ fields2).groupBy(_.name).map {
             case (name, fieldTypes) =>
-              val dataType = fieldTypes.map(_.dataType).reduce(compatibleType(options))
+              val dataType = fieldTypes.map(_.dataType).reduce(compatibleType)
               StructField(name, dataType, nullable = true)
           }
           StructType(newFields.toArray.sortBy(_.name))
 
         case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
           ArrayType(
-            compatibleType(options)(elementType1, elementType2), containsNull1 || containsNull2)
+            compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
 
         // In XML datasource, since StructType can be compared with ArrayType.
         // In this case, ArrayType wraps the StructType.
         case (ArrayType(ty1, _), ty2) =>
-          ArrayType(compatibleType(options)(ty1, ty2))
+          ArrayType(compatibleType(ty1, ty2))
 
         case (ty1, ArrayType(ty2, _)) =>
-          ArrayType(compatibleType(options)(ty1, ty2))
+          ArrayType(compatibleType(ty1, ty2))
 
         // As this library can infer an element with attributes as StructType whereas
         // some can be inferred as other non-structural data types, this case should be
@@ -335,14 +431,14 @@ private[sql] object XmlInferSchema {
         case (st: StructType, dt: DataType) if st.fieldNames.contains(options.valueTag) =>
           val valueIndex = st.fieldNames.indexOf(options.valueTag)
           val valueField = st.fields(valueIndex)
-          val valueDataType = compatibleType(options)(valueField.dataType, dt)
+          val valueDataType = compatibleType(valueField.dataType, dt)
           st.fields(valueIndex) = StructField(options.valueTag, valueDataType, nullable = true)
           st
 
         case (dt: DataType, st: StructType) if st.fieldNames.contains(options.valueTag) =>
           val valueIndex = st.fieldNames.indexOf(options.valueTag)
           val valueField = st.fields(valueIndex)
-          val valueDataType = compatibleType(options)(dt, valueField.dataType)
+          val valueDataType = compatibleType(dt, valueField.dataType)
           st.fields(valueIndex) = StructField(options.valueTag, valueDataType, nullable = true)
           st
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index 7d049fdd82b8..aac6eec21c60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -23,8 +23,7 @@ import javax.xml.stream.XMLInputFactory
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions}
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, DateFormatter, DateTimeUtils, ParseMode, PermissiveMode, TimestampFormatter}
-import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, DateFormatter, DateTimeUtils, ParseMode, PermissiveMode}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 
@@ -32,7 +31,7 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
  * Options for the XML data source.
  */
 private[sql] class XmlOptions(
-    @transient val parameters: CaseInsensitiveMap[String],
+    val parameters: CaseInsensitiveMap[String],
     defaultTimeZoneId: String,
     defaultColumnNameOfCorruptRecord: String,
     rowTagRequired: Boolean)
@@ -147,6 +146,10 @@ private[sql] class XmlOptions(
       s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]"
     })
 
+  val timestampNTZFormatInRead: Option[String] = parameters.get(TIMESTAMP_NTZ_FORMAT)
+  val timestampNTZFormatInWrite: String =
+    parameters.getOrElse(TIMESTAMP_NTZ_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]")
+
   val timezone = parameters.get("timezone")
 
   val zoneId: ZoneId = DateTimeUtils.getZoneId(
@@ -163,32 +166,6 @@ private[sql] class XmlOptions(
   def buildXmlFactory(): XMLInputFactory = {
     XMLInputFactory.newInstance()
   }
-
-  val timestampFormatter = TimestampFormatter(
-    timestampFormatInRead,
-    zoneId,
-    locale,
-    legacyFormat = FAST_DATE_FORMAT,
-    isParsing = true)
-
-  val timestampFormatterInWrite = TimestampFormatter(
-    timestampFormatInWrite,
-    zoneId,
-    locale,
-    legacyFormat = FAST_DATE_FORMAT,
-    isParsing = false)
-
-  val dateFormatter = DateFormatter(
-    dateFormatInRead,
-    locale,
-    legacyFormat = FAST_DATE_FORMAT,
-    isParsing = true)
-
-  val dateFormatterInWrite = DateFormatter(
-    dateFormatInWrite,
-    locale,
-    legacyFormat = FAST_DATE_FORMAT,
-    isParsing = false)
 }
 
 private[sql] object XmlOptions extends DataSourceOptions {
@@ -225,6 +202,7 @@ private[sql] object XmlOptions extends DataSourceOptions {
   val COLUMN_NAME_OF_CORRUPT_RECORD = newOption("columnNameOfCorruptRecord")
   val DATE_FORMAT = newOption("dateFormat")
   val TIMESTAMP_FORMAT = newOption("timestampFormat")
+  val TIMESTAMP_NTZ_FORMAT = newOption("timestampNTZFormat")
   val TIME_ZONE = newOption("timeZone")
   val INDENT = newOption("indent")
   // Options with alternative
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index bc62003b251e..9992d8cbba07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -565,7 +565,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 4.0.0
    */
   @scala.annotation.varargs
-  def xml(paths: String*): DataFrame = format("xml").load(paths: _*)
+  def xml(paths: String*): DataFrame = {
+    userSpecifiedSchema.foreach(checkXmlSchema)
+    format("xml").load(paths: _*)
+  }
 
   /**
    * Loads an `Dataset[String]` storing XML object and returns the result as a `DataFrame`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
index d96d80d6ce51..b09be84130ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
@@ -121,7 +121,9 @@ object TextInputXmlDataSource extends XmlDataSource {
   def inferFromDataset(
       xml: Dataset[String],
       parsedOptions: XmlOptions): StructType = {
-    XmlInferSchema.infer(xml.rdd, parsedOptions)
+    SQLExecution.withSQLConfPropagated(xml.sparkSession) {
+      new XmlInferSchema(parsedOptions).infer(xml.rdd)
+    }
   }
 
   private def createBaseDataset(
@@ -177,7 +179,7 @@ object MultiLineXmlDataSource extends XmlDataSource {
         parsedOptions)
     }
     SQLExecution.withSQLConfPropagated(sparkSession) {
-      val schema = XmlInferSchema.infer(tokenRDD, parsedOptions)
+      val schema = new XmlInferSchema(parsedOptions).infer(tokenRDD)
       schema
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index fc8f5a416ab1..36dd168992a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.json.JsonUtils.checkJsonSchema
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2}
+import org.apache.spark.sql.execution.datasources.xml.XmlUtils.checkXmlSchema
 import org.apache.spark.sql.execution.streaming.StreamingRelation
 import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.types.StructType
@@ -278,7 +279,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    *
    * @since 4.0.0
    */
-  def xml(path: String): DataFrame = format("xml").load(path)
+  def xml(path: String): DataFrame = {
+    userSpecifiedSchema.foreach(checkXmlSchema)
+    format("xml").load(path)
+  }
 
   /**
    * Loads a ORC file stream, returning the result as a `DataFrame`.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 20600848019d..2d4cd2f403c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.xml
 import java.nio.charset.{StandardCharsets, UnsupportedCharsetException}
 import java.nio.file.{Files, Path, Paths}
 import java.sql.{Date, Timestamp}
+import java.time.Instant
 import java.util.TimeZone
 
 import scala.collection.mutable
@@ -31,15 +32,15 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.io.compress.GzipCodec
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Encoders, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.xml.XmlOptions
 import org.apache.spark.sql.catalyst.xml.XmlOptions._
 import org.apache.spark.sql.execution.datasources.xml.TestUtils._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 class XmlSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -48,9 +49,6 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   private var tempDir: Path = _
 
-  protected override def sparkConf = super.sparkConf
-    .set(SQLConf.SESSION_LOCAL_TIMEZONE.key, "UTC")
-
   override protected def beforeAll(): Unit = {
     super.beforeAll()
     tempDir = Files.createTempDirectory("XmlSuite")
@@ -1511,7 +1509,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val expectedSchema =
       buildSchema(field("author"), field("date", TimestampType), field("date2", DateType))
     assert(df.schema === expectedSchema)
-    assert(df.collect().head.getAs[Timestamp](1).toString === "2021-01-31 16:00:00.0")
+    assert(df.collect().head.getAs[Timestamp](1) === Timestamp.valueOf("2021-02-01 00:00:00"))
     assert(df.collect().head.getAs[Date](2).toString === "2021-02-01")
   }
 
@@ -1556,7 +1554,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val res = df.collect()
     assert(res.head.get(1) === "2011-12-03T10:15:30Z")
     assert(res.head.get(2) === "12-03-2011 10:15:30 PST")
-    assert(res.head.getAs[Timestamp](3).getTime === 1322892930000L)
+    assert(res.head.getAs[Timestamp](3) === Timestamp.valueOf("2011-12-03 06:15:30"))
   }
 
   test("Test custom timestampFormat with offset") {
@@ -1803,4 +1801,174 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       "declaration" -> s"<${XmlOptions.DEFAULT_DECLARATION}>"),
       "'declaration' should not include angle brackets")
   }
+
+  def dataTypeTest(data: String,
+                   dt: DataType): Unit = {
+    val xmlString = s"""<ROW>$data</ROW>"""
+    val schema = new StructType().add(XmlOptions.VALUE_TAG, dt)
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .schema(schema)
+      .xml(spark.createDataset(Seq(xmlString)))
+  }
+
+  test("Primitive field casting") {
+    val ts = Seq("2002-05-30 21:46:54", "2002-05-30T21:46:54", "2002-05-30T21:46:54.1234",
+      "2002-05-30T21:46:54Z", "2002-05-30T21:46:54.1234Z", "2002-05-30T21:46:54-06:00",
+      "2002-05-30T21:46:54+06:00", "2002-05-30T21:46:54.1234-06:00",
+      "2002-05-30T21:46:54.1234+06:00", "2002-05-30T21:46:54+00:00", "2002-05-30T21:46:54.0000Z")
+
+    val tsXMLStr = ts.map(t => s"<TimestampType>$t</TimestampType>").mkString("\n")
+    val tsResult = ts.map(t =>
+      Timestamp.from(Instant.ofEpochSecond(0, DateTimeUtils.stringToTimestamp(
+        UTF8String.fromString(t), DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)).get * 1000))
+    )
+
+    val primitiveFieldAndType: Dataset[String] =
+      spark.createDataset(spark.sparkContext.parallelize(
+        s"""<ROW>
+          <decimal>10.05</decimal>
+          <decimal>1,000.01</decimal>
+          <decimal>158,058,049.001</decimal>
+          <emptyString></emptyString>
+          <ByteType>10</ByteType>
+          <ShortType>10</ShortType>
+          <ShortType>+10</ShortType>
+          <ShortType>-10</ShortType>
+          <IntegerType>10</IntegerType>
+          <IntegerType>+10</IntegerType>
+          <IntegerType>-10</IntegerType>
+          <LongType>10</LongType>
+          <LongType>+10</LongType>
+          <LongType>-10</LongType>
+          <FloatType>1.00</FloatType>
+          <FloatType>+1.00</FloatType>
+          <FloatType>-1.00</FloatType>
+          <DoubleType>1.00</DoubleType>
+          <DoubleType>+1.00</DoubleType>
+          <DoubleType>-1.00</DoubleType>
+          <BooleanType>true</BooleanType>
+          <BooleanType>1</BooleanType>
+          <BooleanType>false</BooleanType>
+          <BooleanType>0</BooleanType>
+          $tsXMLStr
+          <DateType>2002-09-24</DateType>
+        </ROW>""".stripMargin :: Nil))(Encoders.STRING)
+
+    val decimalType = DecimalType(20, 3)
+
+    val schema = StructType(
+      StructField("decimal", ArrayType(decimalType), true) ::
+        StructField("emptyString", StringType, true) ::
+        StructField("ByteType", ByteType, true) ::
+        StructField("ShortType", ArrayType(ShortType), true) ::
+        StructField("IntegerType", ArrayType(IntegerType), true) ::
+        StructField("LongType", ArrayType(LongType), true) ::
+        StructField("FloatType", ArrayType(FloatType), true) ::
+        StructField("DoubleType", ArrayType(DoubleType), true) ::
+        StructField("BooleanType", ArrayType(BooleanType), true) ::
+        StructField("TimestampType", ArrayType(TimestampType), true) ::
+        StructField("DateType", DateType, true) :: Nil)
+
+    val df = spark.read.schema(schema).xml(primitiveFieldAndType)
+
+    checkAnswer(
+      df,
+      Seq(Row(Array(
+        Decimal(BigDecimal("10.05"), decimalType.precision, decimalType.scale).toJavaBigDecimal,
+        Decimal(BigDecimal("1000.01"), decimalType.precision, decimalType.scale).toJavaBigDecimal,
+        Decimal(BigDecimal("158058049.001"), decimalType.precision, decimalType.scale)
+          .toJavaBigDecimal),
+        "",
+        10.toByte,
+        Array(10.toShort, 10.toShort, -10.toShort),
+        Array(10, 10, -10),
+        Array(10L, 10L, -10L),
+        Array(1.0.toFloat, 1.0.toFloat, -1.0.toFloat),
+        Array(1.0, 1.0, -1.0),
+        Array(true, true, false, false),
+        tsResult,
+        Date.valueOf("2002-09-24")
+      ))
+    )
+  }
+
+  test("Nullable types are handled") {
+    val dataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
+      BooleanType, TimestampType, DateType, StringType)
+
+    val dataXMLString = dataTypes.map { dt =>
+      s"""<${dt.toString}>-</${dt.toString}>"""
+    }.mkString("\n")
+
+    val fields = dataTypes.map { dt =>
+      StructField(dt.toString, dt, true)
+    }
+    val schema = StructType(fields)
+
+    val res = dataTypes.map { dt => null }
+
+    val nullDataset: Dataset[String] =
+      spark.createDataset(spark.sparkContext.parallelize(
+        s"""<ROW>
+          $dataXMLString
+        </ROW>""".stripMargin :: Nil))(Encoders.STRING)
+
+    val df = spark.read.option("nullValue", "-").schema(schema).xml(nullDataset)
+    checkAnswer(df, Row.fromSeq(res))
+
+    val df2 = spark.read.xml(nullDataset)
+    checkAnswer(df2, Row.fromSeq(dataTypes.map { dt => "-" }))
+  }
+
+  test("Custom timestamp format is used to parse correctly") {
+    val schema = StructType(
+      StructField("ts", TimestampType, true) :: Nil)
+
+    Seq(
+      ("12-03-2011 10:15:30", "2011-12-03 10:15:30", "MM-dd-yyyy HH:mm:ss", "UTC"),
+      ("2011/12/03 10:15:30", "2011-12-03 10:15:30", "yyyy/MM/dd HH:mm:ss", "UTC"),
+      ("2011/12/03 10:15:30", "2011-12-03 10:15:30", "yyyy/MM/dd HH:mm:ss", "Asia/Shanghai")
+    ).foreach { case (ts, resTS, fmt, zone) =>
+      val tsDataset: Dataset[String] =
+        spark.createDataset(spark.sparkContext.parallelize(
+          s"""<ROW>
+          <ts>$ts</ts>
+        </ROW>""".stripMargin :: Nil))(Encoders.STRING)
+      val timestampResult = Timestamp.from(Instant.ofEpochSecond(0,
+        DateTimeUtils.stringToTimestamp(UTF8String.fromString(resTS),
+          DateTimeUtils.getZoneId(zone)).get * 1000))
+
+      val df = spark.read.option("timestampFormat", fmt).option("timezone", zone)
+        .schema(schema).xml(tsDataset)
+      checkAnswer(df, Row(timestampResult))
+    }
+  }
+
+  test("Schema Inference for primitive types") {
+    val dataset: Dataset[String] =
+      spark.createDataset(spark.sparkContext.parallelize(
+        s"""<ROW>
+          <bool1>true</bool1>
+          <double1>+10.1</double1>
+          <long1>-10</long1>
+          <long2>10</long2>
+          <string1>8E9D</string1>
+          <string2>8E9F</string2>
+          <ts1>2015-01-01 00:00:00</ts1>
+        </ROW>""".stripMargin :: Nil))(Encoders.STRING)
+
+    val expectedSchema = StructType(StructField("bool1", BooleanType, true) ::
+      StructField("double1", DoubleType, true) ::
+      StructField("long1", LongType, true) ::
+      StructField("long2", LongType, true) ::
+      StructField("string1", StringType, true) ::
+      StructField("string2", StringType, true) ::
+      StructField("ts1", TimestampType, true) :: Nil)
+
+    val df = spark.read.xml(dataset)
+    assert(df.schema.toSet === expectedSchema.toSet)
+    checkAnswer(df, Row(true, 10.1, -10, 10, "8E9D", "8E9F",
+      Timestamp.valueOf("2015-01-01 00:00:00")))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/TypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/TypeCastSuite.scala
deleted file mode 100644
index 096fb3d83a54..000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/util/TypeCastSuite.scala
+++ /dev/null
@@ -1,234 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution.datasources.xml.util
-
-import java.math.BigDecimal
-import java.util.Locale
-
-import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
-import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
-import org.apache.spark.sql.catalyst.xml.{TypeCast, XmlOptions}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-final class TypeCastSuite extends SharedSparkSession {
-
-  test("Can parse decimal type values") {
-    val options = new XmlOptions()
-    val stringValues = Seq("10.05", "1,000.01", "158,058,049.001")
-    val decimalValues = Seq(10.05, 1000.01, 158058049.001)
-    val decimalType = DecimalType.SYSTEM_DEFAULT
-
-    stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) =>
-      val dt = new BigDecimal(decimalVal.toString)
-      assert(TypeCast.castTo(strVal, decimalType, options) ===
-        Decimal(dt, dt.precision(), dt.scale()))
-    }
-  }
-
-  test("Nullable types are handled") {
-    val options = new XmlOptions(Map("nullValue" -> "-"))
-    for (t <- Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
-                  BooleanType, TimestampType, DateType, StringType)) {
-      assert(TypeCast.castTo("-", t, options) === null)
-    }
-  }
-
-  test("String type should always return the same as the input") {
-    val options = new XmlOptions()
-    assert(TypeCast.castTo("", StringType, options) === UTF8String.fromString(""))
-  }
-
-  test("Types are cast correctly") {
-    val options = new XmlOptions()
-    assert(TypeCast.castTo("10", ByteType, options) === 10)
-    assert(TypeCast.castTo("10", ShortType, options) === 10)
-    assert(TypeCast.castTo("10", IntegerType, options) === 10)
-    assert(TypeCast.castTo("10", LongType, options) === 10)
-    assert(TypeCast.castTo("1.00", FloatType, options) === 1.0)
-    assert(TypeCast.castTo("1.00", DoubleType, options) === 1.0)
-    assert(TypeCast.castTo("true", BooleanType, options) === true)
-    assert(TypeCast.castTo("1", BooleanType, options) === true)
-    assert(TypeCast.castTo("false", BooleanType, options) === false)
-    assert(TypeCast.castTo("0", BooleanType, options) === false)
-
-    {
-      val ts = TypeCast.castTo("2002-05-30 21:46:54", TimestampType, options)
-      assert(ts === 1022820414000000L)
-      assert(ts ===
-        TimestampFormatter(None, 
DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30 21:46:54"))
-    }
-    {
-      val ts = TypeCast.castTo("2002-05-30T21:46:54", TimestampType, options)
-      assert(ts === 1022820414000000L)
-      assert(ts ===
-        TimestampFormatter(None, DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54"))
-    }
-    {
-      val ts = TypeCast.castTo("2002-05-30T21:46:54.1234", TimestampType, options)
-      assert(ts === 1022820414123400L)
-      assert(ts ===
-        TimestampFormatter(None, 
DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54.1234"))
-    }
-    {
-      val ts = TypeCast.castTo("2002-05-30T21:46:54Z", TimestampType, options)
-      assert(ts === 1022795214000000L)
-      assert(ts ===
-        TimestampFormatter(None, DateTimeUtils.getZoneId("UTC"),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54Z"))
-    }
-    {
-      val ts = TypeCast.castTo("2002-05-30T21:46:54-06:00", TimestampType, options)
-      assert(ts === 1022816814000000L)
-      assert(ts ===
-        TimestampFormatter(None, DateTimeUtils.getZoneId("-06:00"),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54-06:00"))
-    }
-    {
-      val ts = TypeCast.castTo("2002-05-30T21:46:54+06:00", TimestampType, options)
-      assert(ts === 1022773614000000L)
-      assert(ts ===
-        TimestampFormatter(None, DateTimeUtils.getZoneId("+06:00"),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54+06:00"))
-    }
-    {
-      val ts = TypeCast.castTo("2002-05-30T21:46:54.1234Z", TimestampType, options)
-      assert(ts === 1022795214123400L)
-      assert(ts ===
-        TimestampFormatter(None, DateTimeUtils.getZoneId("UTC"),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54.1234Z"))
-    }
-    {
-      val ts = TypeCast.castTo("2002-05-30T21:46:54.1234-06:00", TimestampType, options)
-      assert(ts === 1022816814123400L)
-      assert(ts ===
-        TimestampFormatter(None, DateTimeUtils.getZoneId("-06:00"),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54.1234-06:00"))
-    }
-    {
-      val ts = TypeCast.castTo("2002-05-30T21:46:54.1234+06:00", TimestampType, options)
-      assert(ts === 1022773614123400L)
-      assert(ts ===
-        TimestampFormatter(None, DateTimeUtils.getZoneId("+06:00"),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54.1234+06:00"))
-    }
-    {
-      val date = TypeCast.castTo("2002-09-24", DateType, options)
-      assert(date === 11954)
-      assert(date === DateFormatter(DateFormatter.defaultPattern,
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-09-24"))
-    }
-  }
-
-  test("Types with sign are cast correctly") {
-    val options = new XmlOptions()
-    assert(TypeCast.signSafeToInt("+10", options) === 10)
-    assert(TypeCast.signSafeToLong("-10", options) === -10)
-    assert(TypeCast.signSafeToFloat("1.00", options) === 1.0)
-    assert(TypeCast.signSafeToDouble("-1.00", options) === -1.0)
-  }
-
-  test("Types with sign are checked correctly") {
-    assert(TypeCast.isBoolean("true"))
-    assert(TypeCast.isInteger("10"))
-    assert(TypeCast.isLong("10"))
-    assert(TypeCast.isDouble("+10.1"))
-    assert(!TypeCast.isDouble("8E9D"))
-    assert(!TypeCast.isDouble("8E9F"))
-    val timestamp = "2015-01-01 00:00:00"
-    assert(TypeCast.isTimestamp(timestamp, new XmlOptions()))
-  }
-
-  test("Float and Double Types are cast correctly with Locale") {
-    val options = new XmlOptions()
-    val defaultLocale = Locale.getDefault
-    try {
-      Locale.setDefault(Locale.FRANCE)
-      assert(TypeCast.castTo("1,00", FloatType, options) === 1.0)
-      assert(TypeCast.castTo("1,00", DoubleType, options) === 1.0)
-    } finally {
-      Locale.setDefault(defaultLocale)
-    }
-  }
-
-  test("Parsing built-in timestamp formatters") {
-    val options = XmlOptions(Map())
-    val expectedResult =
-      TimestampFormatter(None, DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone),
-      Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30 21:46:54")
-    assert(
-      TypeCast.castTo("2002-05-30 21:46:54", TimestampType, options) === expectedResult
-    )
-    assert(
-      TypeCast.castTo("2002-05-30T21:46:54", TimestampType, options) === expectedResult
-    )
-    assert(
-      TypeCast.castTo("2002-05-30T21:46:54+00:00", TimestampType, options) ===
-        TimestampFormatter(None, DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54+00:00")
-    )
-    assert(
-      TypeCast.castTo("2002-05-30T21:46:54.0000Z", TimestampType, options) ===
-        TimestampFormatter(None, DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2002-05-30T21:46:54.0000Z")
-    )
-  }
-
-  test("Custom timestamp format is used to parse correctly") {
-    var options = XmlOptions(Map("timestampFormat" -> "MM-dd-yyyy HH:mm:ss", "timezone" -> "UTC"))
-    assert(
-      TypeCast.castTo("12-03-2011 10:15:30", TimestampType, options) ===
-        TimestampFormatter("MM-dd-yyyy HH:mm:ss", DateTimeUtils.getZoneId("UTC"),
-          Locale.US, FAST_DATE_FORMAT, true).parse("12-03-2011 10:15:30")
-    )
-
-    options = XmlOptions(Map("timestampFormat" -> "yyyy/MM/dd HH:mm:ss", "timezone" -> "UTC"))
-    assert(
-      TypeCast.castTo("2011/12/03 10:15:30", TimestampType, options) ===
-        TimestampFormatter("yyyy/MM/dd HH:mm:ss", DateTimeUtils.getZoneId("UTC"),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2011/12/03 10:15:30")
-    )
-
-    options = XmlOptions(Map("timestampFormat" -> "yyyy/MM/dd HH:mm:ss",
-      "timezone" -> "Asia/Shanghai"))
-    assert(
-      TypeCast.castTo("2011/12/03 10:15:30", TimestampType, options) ===
-        TimestampFormatter("yyyy/MM/dd HH:mm:ss", DateTimeUtils.getZoneId("Asia/Shanghai"),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2011/12/03 10:15:30")
-    )
-
-    options = XmlOptions(Map("timestampFormat" -> "yyyy/MM/dd HH:mm:ss",
-      "timezone" -> "Asia/Shanghai"))
-    assert(
-      TypeCast.castTo("2011/12/03 10:15:30", TimestampType, options) ===
-        TimestampFormatter("yyyy/MM/dd HH:mm:ss", DateTimeUtils.getZoneId("Asia/Shanghai"),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2011/12/03 10:15:30")
-    )
-
-    options = XmlOptions(Map("timestampFormat" -> "yyyy/MM/dd HH:mm:ss"))
-    assert(TypeCast.castTo("2011/12/03 10:15:30", TimestampType, options) ===
-        TimestampFormatter("yyyy/MM/dd HH:mm:ss",
-          DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone),
-          Locale.US, FAST_DATE_FORMAT, true).parse("2011/12/03 10:15:30")
-    )
-  }
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
