This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 148f5335427c [SPARK-47297][SQL] Add collation support for format expressions 148f5335427c is described below commit 148f5335427c3aea39cbcce967e18a3b35a88687 Author: Uros Bojanic <157381213+uros...@users.noreply.github.com> AuthorDate: Tue May 7 23:00:30 2024 +0800 [SPARK-47297][SQL] Add collation support for format expressions ### What changes were proposed in this pull request? Introduce collation awareness for format expressions: to_number, try_to_number, to_char, space. ### Why are the changes needed? Add collation support for format expressions in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use collated strings within arguments for format functions: to_number, try_to_number, to_char, space. ### How was this patch tested? E2e sql tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46423 from uros-db/format-expressions. Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../expressions/numberFormatExpressions.scala | 14 ++- .../catalyst/expressions/stringExpressions.scala | 2 +- .../spark/sql/CollationSQLExpressionsSuite.scala | 132 ++++++++++++++++++++- 3 files changed, 141 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index 6d95d7e620a2..e914190c0645 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import 
org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.util.ToNumberParser import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, DatetimeType, Decimal, DecimalType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -47,7 +49,8 @@ abstract class ToNumberBase(left: Expression, right: Expression, errorOnFail: Bo DecimalType.USER_DEFAULT } - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() @@ -247,8 +250,9 @@ object ToCharacterBuilder extends ExpressionBuilder { inputExpr.dataType match { case _: DatetimeType => DateFormatClass(inputExpr, format) case _: BinaryType => - if (!(format.dataType == StringType && format.foldable)) { - throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "format", StringType) + if (!(format.dataType.isInstanceOf[StringType] && format.foldable)) { + throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "format", + format.dataType) } val fmt = format.eval() if (fmt == null) { @@ -279,8 +283,8 @@ case class ToCharacter(left: Expression, right: Expression) } } - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringType) + override def dataType: DataType = SQLConf.get.defaultStringType + override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringTypeAnyCollation) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() if (inputTypeCheck.isSuccess) { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 0769c8e609ec..c2ea17de1953 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1906,7 +1906,7 @@ case class StringRepeat(str: Expression, times: Expression) case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def inputTypes: Seq[DataType] = Seq(IntegerType) override def nullSafeEval(s: Any): Any = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 596923d975a5..4314ff97a3cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql import scala.collection.immutable.Seq +import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{MapType, StringType} +import org.apache.spark.sql.types._ // scalastyle:off nonascii class CollationSQLExpressionsSuite @@ -330,6 +331,135 @@ class CollationSQLExpressionsSuite }) } + test("Support StringSpace expression with collation") { + case class StringSpaceTestCase( + input: Int, + collationName: String, + result: String + ) + + val testCases = Seq( + StringSpaceTestCase(1, "UTF8_BINARY", " "), + StringSpaceTestCase(2, "UTF8_BINARY_LCASE", "  "), + StringSpaceTestCase(3,
"UNICODE", "   "), + StringSpaceTestCase(4, "UNICODE_CI", "    ") + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select space(${t.input}) + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support ToNumber & TryToNumber expressions with collation") { + case class ToNumberTestCase( + input: String, + collationName: String, + format: String, + result: Any, + resultType: DataType + ) + + val testCases = Seq( + ToNumberTestCase("123", "UTF8_BINARY", "999", 123, DecimalType(3, 0)), + ToNumberTestCase("1", "UTF8_BINARY_LCASE", "0.00", 1.00, DecimalType(3, 2)), + ToNumberTestCase("99,999", "UNICODE", "99,999", 99999, DecimalType(5, 0)), + ToNumberTestCase("$14.99", "UNICODE_CI", "$99.99", 14.99, DecimalType(4, 2)) + ) + + // Supported collations (ToNumber) + testCases.foreach(t => { + val query = + s""" + |select to_number('${t.input}', '${t.format}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + assert(testQuery.schema.fields.head.dataType.sameType(t.resultType)) + } + }) + + // Supported collations (TryToNumber) + testCases.foreach(t => { + val query = + s""" + |select try_to_number('${t.input}', '${t.format}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + assert(testQuery.schema.fields.head.dataType.sameType(t.resultType)) + } + }) + } + + test("Handle invalid number for ToNumber variant expression with collation") { + // to_number should throw an exception if the conversion fails + val number = "xx" + val query =
s"SELECT to_number('$number', '999');" + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + val e = intercept[SparkIllegalArgumentException] { + val testQuery = sql(query) + testQuery.collect() + } + assert(e.getErrorClass === "INVALID_FORMAT.MISMATCH_INPUT") + } + } + + test("Handle invalid number for TryToNumber variant expression with collation") { + // try_to_number shouldn't throw an exception if the conversion fails + val number = "xx" + val query = s"SELECT try_to_number('$number', '999');" + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") { + val testQuery = sql(query) + checkAnswer(testQuery, Row(null)) + } + } + + test("Support ToChar expression with collation") { + case class ToCharTestCase( + input: Int, + collationName: String, + format: String, + result: String + ) + + val testCases = Seq( + ToCharTestCase(12, "UTF8_BINARY", "999", " 12"), + ToCharTestCase(34, "UTF8_BINARY_LCASE", "000D00", "034.00"), + ToCharTestCase(56, "UNICODE", "$99.99", "$56.00"), + ToCharTestCase(78, "UNICODE_CI", "99D9S", "78.0+") + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select to_char(${t.input}, '${t.format}') + |""".stripMargin + // Result & data type + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + test("Support StringToMap expression with collation") { // Supported collations case class StringToMapTestCase[R](t: String, p: String, k: String, c: String, result: R) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org