This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 8c0a7ba82c98 [SPARK-48160][SQL] Add collation support for XPATH expressions 8c0a7ba82c98 is described below commit 8c0a7ba82c98c7f7e686c4ee81d2aad49cc7a6e0 Author: Uros Bojanic <157381213+uros...@users.noreply.github.com> AuthorDate: Wed May 15 14:24:46 2024 +0800 [SPARK-48160][SQL] Add collation support for XPATH expressions ### What changes were proposed in this pull request? Introduce collation awareness for XPath expressions: xpath_boolean, xpath_short, xpath_int, xpath_long, xpath_float, xpath_double, xpath_string, xpath. ### Why are the changes needed? Add collation support for XPath expressions in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use collated strings within arguments for XPath functions: xpath_boolean, xpath_short, xpath_int, xpath_long, xpath_float, xpath_double, xpath_string, xpath. ### How was this patch tested? End-to-end SQL tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46508 from uros-db/xpath-expressions. 
Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/expressions/xml/xpath.scala | 11 ++++-- .../spark/sql/CollationSQLExpressionsSuite.scala | 44 ++++++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index c3a285178c11..f65061e8d0ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -39,7 +41,8 @@ abstract class XPathExtract /** XPath expressions are always nullable, e.g. if the xml string is empty. 
*/ override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def checkInputDataTypes(): TypeCheckResult = { if (!path.foldable) { @@ -47,7 +50,7 @@ abstract class XPathExtract errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("path"), - "inputType" -> toSQLType(StringType), + "inputType" -> toSQLType(StringTypeAnyCollation), "inputExpr" -> toSQLExpr(path) ) ) @@ -221,7 +224,7 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { // scalastyle:on line.size.limit case class XPathString(xml: Expression, path: Expression) extends XPathExtract { override def prettyName: String = "xpath_string" - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullSafeEval(xml: Any, path: Any): Any = { val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString) @@ -245,7 +248,7 @@ case class XPathString(xml: Expression, path: Expression) extends XPathExtract { // scalastyle:on line.size.limit case class XPathList(xml: Expression, path: Expression) extends XPathExtract { override def prettyName: String = "xpath" - override def dataType: DataType = ArrayType(StringType, containsNull = false) + override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType, containsNull = false) override def nullSafeEval(xml: Any, path: Any): Any = { val nodeList = xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 48c3853bb5cf..37dcdf9bd721 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -548,6 +548,50 @@ class CollationSQLExpressionsSuite }) } + test("Support XPath expressions with collation") { + case class XPathTestCase( + xml: String, + xpath: String, + functionName: String, + collationName: String, + result: Any, + resultType: DataType + ) + + val testCases = Seq( + XPathTestCase("<a><b>1</b></a>", "a/b", + "xpath_boolean", "UTF8_BINARY", true, BooleanType), + XPathTestCase("<A><B>1</B><B>2</B></A>", "sum(A/B)", + "xpath_short", "UTF8_BINARY", 3, ShortType), + XPathTestCase("<a><b>3</b><b>4</b></a>", "sum(a/b)", + "xpath_int", "UTF8_BINARY_LCASE", 7, IntegerType), + XPathTestCase("<A><B>5</B><B>6</B></A>", "sum(A/B)", + "xpath_long", "UTF8_BINARY_LCASE", 11, LongType), + XPathTestCase("<a><b>7</b><b>8</b></a>", "sum(a/b)", + "xpath_float", "UNICODE", 15.0, FloatType), + XPathTestCase("<A><B>9</B><B>0</B></A>", "sum(A/B)", + "xpath_double", "UNICODE", 9.0, DoubleType), + XPathTestCase("<a><b>b</b><c>cc</c></a>", "a/c", + "xpath_string", "UNICODE_CI", "cc", StringType("UNICODE_CI")), + XPathTestCase("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b/text()", + "xpath", "UNICODE_CI", Array("b1", "b2", "b3"), ArrayType(StringType("UNICODE_CI"))) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select ${t.functionName}('${t.xml}', '${t.xpath}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + assert(testQuery.schema.fields.head.dataType.sameType(t.resultType)) + } + }) + } + test("Support StringSpace expression with collation") { case class StringSpaceTestCase( input: Int, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org