This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c0a7ba82c98 [SPARK-48160][SQL] Add collation support for XPATH 
expressions
8c0a7ba82c98 is described below

commit 8c0a7ba82c98c7f7e686c4ee81d2aad49cc7a6e0
Author: Uros Bojanic <157381213+uros...@users.noreply.github.com>
AuthorDate: Wed May 15 14:24:46 2024 +0800

    [SPARK-48160][SQL] Add collation support for XPATH expressions
    
    ### What changes were proposed in this pull request?
    Introduce collation awareness for XPath expressions: xpath_boolean, 
xpath_short, xpath_int, xpath_long, xpath_float, xpath_double, xpath_string, 
xpath.
    
    ### Why are the changes needed?
    Add collation support for XPath expressions in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users can now pass collated strings as arguments to the XPath 
functions: xpath_boolean, xpath_short, xpath_int, xpath_long, 
xpath_float, xpath_double, xpath_string, xpath.
    
    ### How was this patch tested?
    End-to-end SQL tests in CollationSQLExpressionsSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46508 from uros-db/xpath-expressions.
    
    Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/expressions/xml/xpath.scala | 11 ++++--
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 44 ++++++++++++++++++++++
 2 files changed, 51 insertions(+), 4 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index c3a285178c11..f65061e8d0ea 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -39,7 +41,8 @@ abstract class XPathExtract
   /** XPath expressions are always nullable, e.g. if the xml string is empty. 
*/
   override def nullable: Boolean = true
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (!path.foldable) {
@@ -47,7 +50,7 @@ abstract class XPathExtract
         errorSubClass = "NON_FOLDABLE_INPUT",
         messageParameters = Map(
           "inputName" -> toSQLId("path"),
-          "inputType" -> toSQLType(StringType),
+          "inputType" -> toSQLType(StringTypeAnyCollation),
           "inputExpr" -> toSQLExpr(path)
         )
       )
@@ -221,7 +224,7 @@ case class XPathDouble(xml: Expression, path: Expression) 
extends XPathExtract {
 // scalastyle:on line.size.limit
 case class XPathString(xml: Expression, path: Expression) extends XPathExtract 
{
   override def prettyName: String = "xpath_string"
-  override def dataType: DataType = StringType
+  override def dataType: DataType = SQLConf.get.defaultStringType
 
   override def nullSafeEval(xml: Any, path: Any): Any = {
     val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, 
pathString)
@@ -245,7 +248,7 @@ case class XPathString(xml: Expression, path: Expression) 
extends XPathExtract {
 // scalastyle:on line.size.limit
 case class XPathList(xml: Expression, path: Expression) extends XPathExtract {
   override def prettyName: String = "xpath"
-  override def dataType: DataType = ArrayType(StringType, containsNull = false)
+  override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType, 
containsNull = false)
 
   override def nullSafeEval(xml: Any, path: Any): Any = {
     val nodeList = 
xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 48c3853bb5cf..37dcdf9bd721 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -548,6 +548,50 @@ class CollationSQLExpressionsSuite
     })
   }
 
+  test("Support XPath expressions with collation") {
+    case class XPathTestCase(
+      xml: String,
+      xpath: String,
+      functionName: String,
+      collationName: String,
+      result: Any,
+      resultType: DataType
+    )
+
+    val testCases = Seq(
+      XPathTestCase("<a><b>1</b></a>", "a/b",
+        "xpath_boolean", "UTF8_BINARY", true, BooleanType),
+      XPathTestCase("<A><B>1</B><B>2</B></A>", "sum(A/B)",
+        "xpath_short", "UTF8_BINARY", 3, ShortType),
+      XPathTestCase("<a><b>3</b><b>4</b></a>", "sum(a/b)",
+        "xpath_int", "UTF8_BINARY_LCASE", 7, IntegerType),
+      XPathTestCase("<A><B>5</B><B>6</B></A>", "sum(A/B)",
+        "xpath_long", "UTF8_BINARY_LCASE", 11, LongType),
+      XPathTestCase("<a><b>7</b><b>8</b></a>", "sum(a/b)",
+        "xpath_float", "UNICODE", 15.0, FloatType),
+      XPathTestCase("<A><B>9</B><B>0</B></A>", "sum(A/B)",
+        "xpath_double", "UNICODE", 9.0, DoubleType),
+      XPathTestCase("<a><b>b</b><c>cc</c></a>", "a/c",
+        "xpath_string", "UNICODE_CI", "cc", StringType("UNICODE_CI")),
+      XPathTestCase("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", 
"a/b/text()",
+        "xpath", "UNICODE_CI", Array("b1", "b2", "b3"), 
ArrayType(StringType("UNICODE_CI")))
+    )
+
+    // Supported collations
+    testCases.foreach(t => {
+      val query =
+        s"""
+           |select ${t.functionName}('${t.xml}', '${t.xpath}')
+           |""".stripMargin
+      // Result & data type
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
+        val testQuery = sql(query)
+        checkAnswer(testQuery, Row(t.result))
+        assert(testQuery.schema.fields.head.dataType.sameType(t.resultType))
+      }
+    })
+  }
+
   test("Support StringSpace expression with collation") {
     case class StringSpaceTestCase(
       input: Int,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to