This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b9f2270f5b0b [SPARK-47352][SQL] Fix Upper, Lower, InitCap collation awareness b9f2270f5b0b is described below commit b9f2270f5b0ba6ea1fb1cdf3225fa626ab91540b Author: Mihailo Milosevic <mihailo.milose...@databricks.com> AuthorDate: Tue Apr 23 16:28:33 2024 +0800 [SPARK-47352][SQL] Fix Upper, Lower, InitCap collation awareness ### What changes were proposed in this pull request? Add support for Locale-aware expressions. ### Why are the changes needed? This is needed as some future collations might use different Locales than the default. ### Does this PR introduce _any_ user-facing change? Yes, we follow ICU implementations for collations that are non-native. ### How was this patch tested? Tests for Upper, Lower and InitCap already exist. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46104 from mihailom-db/SPARK-47352. Authored-by: Mihailo Milosevic <mihailo.milose...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/util/CollationSupport.java | 108 +++++++++++++++ .../spark/unsafe/types/CollationSupportSuite.java | 151 +++++++++++++++++++++ .../catalyst/expressions/stringExpressions.scala | 24 ++-- 3 files changed, 271 insertions(+), 12 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index d54e297413f4..b28321230840 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -16,7 +16,10 @@ */ package org.apache.spark.sql.catalyst.util; +import com.ibm.icu.lang.UCharacter; +import com.ibm.icu.text.BreakIterator; import com.ibm.icu.text.StringSearch; +import com.ibm.icu.util.ULocale; import org.apache.spark.unsafe.types.UTF8String; @@ -144,6 
+147,93 @@ public final class CollationSupport { } } + public static class Upper { + public static UTF8String exec(final UTF8String v, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) { + return execUTF8(v); + } else { + return execICU(v, collationId); + } + } + public static String genCode(final String v, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.Upper.exec"; + if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) { + return String.format(expr + "UTF8(%s)", v); + } else { + return String.format(expr + "ICU(%s, %d)", v, collationId); + } + } + public static UTF8String execUTF8(final UTF8String v) { + return v.toUpperCase(); + } + public static UTF8String execICU(final UTF8String v, final int collationId) { + return UTF8String.fromString(CollationAwareUTF8String.toUpperCase(v.toString(), collationId)); + } + } + + public static class Lower { + public static UTF8String exec(final UTF8String v, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) { + return execUTF8(v); + } else { + return execICU(v, collationId); + } + } + public static String genCode(final String v, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.Lower.exec"; + if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) { + return String.format(expr + "UTF8(%s)", v); + } else { + return String.format(expr + "ICU(%s, %d)", v, collationId); + } + } + public static UTF8String execUTF8(final UTF8String v) { + return v.toLowerCase(); + } + public static UTF8String execICU(final UTF8String v, 
final int collationId) { + return UTF8String.fromString(CollationAwareUTF8String.toLowerCase(v.toString(), collationId)); + } + } + + public static class InitCap { + public static UTF8String exec(final UTF8String v, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) { + return execUTF8(v); + } else { + return execICU(v, collationId); + } + } + + public static String genCode(final String v, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.InitCap.exec"; + if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) { + return String.format(expr + "UTF8(%s)", v); + } else { + return String.format(expr + "ICU(%s, %d)", v, collationId); + } + } + + public static UTF8String execUTF8(final UTF8String v) { + return v.toLowerCase().toTitleCase(); + } + + public static UTF8String execICU(final UTF8String v, final int collationId) { + return UTF8String.fromString( + CollationAwareUTF8String.toTitleCase( + CollationAwareUTF8String.toLowerCase( + v.toString(), + collationId + ), + collationId)); + } + } + public static class FindInSet { public static int exec(final UTF8String word, final UTF8String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); @@ -234,6 +324,24 @@ public final class CollationSupport { private static class CollationAwareUTF8String { + private static String toUpperCase(final String target, final int collationId) { + ULocale locale = CollationFactory.fetchCollation(collationId) + .collator.getLocale(ULocale.ACTUAL_LOCALE); + return UCharacter.toUpperCase(locale, target); + } + + private static String toLowerCase(final String target, final int collationId) { + ULocale locale = CollationFactory.fetchCollation(collationId) + 
.collator.getLocale(ULocale.ACTUAL_LOCALE); + return UCharacter.toLowerCase(locale, target); + } + + private static String toTitleCase(final String target, final int collationId) { + ULocale locale = CollationFactory.fetchCollation(collationId) + .collator.getLocale(ULocale.ACTUAL_LOCALE); + return UCharacter.toTitleCase(locale, target, BreakIterator.getWordInstance(locale)); + } + private static int findInSet(final UTF8String match, final UTF8String set, int collationId) { if (match.contains(UTF8String.fromString(","))) { return 0; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 36acf1c9b7a6..3fca7296b832 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -261,6 +261,157 @@ public class CollationSupportSuite { assertEndsWith("The i̇o", "İo", "UNICODE_CI", true); } + + private void assertUpper(String target, String collationName, String expected) + throws SparkException { + UTF8String target_utf8 = UTF8String.fromString(target); + UTF8String expected_utf8 = UTF8String.fromString(expected); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(expected_utf8, CollationSupport.Upper.exec(target_utf8, collationId)); + } + + @Test + public void testUpper() throws SparkException { + // Edge cases + assertUpper("", "UTF8_BINARY", ""); + assertUpper("", "UTF8_BINARY_LCASE", ""); + assertUpper("", "UNICODE", ""); + assertUpper("", "UNICODE_CI", ""); + // Basic tests + assertUpper("abcde", "UTF8_BINARY", "ABCDE"); + assertUpper("abcde", "UTF8_BINARY_LCASE", "ABCDE"); + assertUpper("abcde", "UNICODE", "ABCDE"); + assertUpper("abcde", "UNICODE_CI", "ABCDE"); + // Uppercase present + assertUpper("AbCdE", "UTF8_BINARY", "ABCDE"); + assertUpper("aBcDe", "UTF8_BINARY", 
"ABCDE"); + assertUpper("AbCdE", "UTF8_BINARY_LCASE", "ABCDE"); + assertUpper("aBcDe", "UTF8_BINARY_LCASE", "ABCDE"); + assertUpper("AbCdE", "UNICODE", "ABCDE"); + assertUpper("aBcDe", "UNICODE", "ABCDE"); + assertUpper("AbCdE", "UNICODE_CI", "ABCDE"); + assertUpper("aBcDe", "UNICODE_CI", "ABCDE"); + // Accent letters + assertUpper("aBćDe","UTF8_BINARY", "ABĆDE"); + assertUpper("aBćDe","UTF8_BINARY_LCASE", "ABĆDE"); + assertUpper("aBćDe","UNICODE", "ABĆDE"); + assertUpper("aBćDe","UNICODE_CI", "ABĆDE"); + // Variable byte length characters + assertUpper("ab世De", "UTF8_BINARY", "AB世DE"); + assertUpper("äbćδe", "UTF8_BINARY", "ÄBĆΔE"); + assertUpper("ab世De", "UTF8_BINARY_LCASE", "AB世DE"); + assertUpper("äbćδe", "UTF8_BINARY_LCASE", "ÄBĆΔE"); + assertUpper("ab世De", "UNICODE", "AB世DE"); + assertUpper("äbćδe", "UNICODE", "ÄBĆΔE"); + assertUpper("ab世De", "UNICODE_CI", "AB世DE"); + assertUpper("äbćδe", "UNICODE_CI", "ÄBĆΔE"); + // Case-variable character length + assertUpper("i̇o", "UTF8_BINARY","İO"); + assertUpper("i̇o", "UTF8_BINARY_LCASE","İO"); + assertUpper("i̇o", "UNICODE","İO"); + assertUpper("i̇o", "UNICODE_CI","İO"); + } + + private void assertLower(String target, String collationName, String expected) + throws SparkException { + UTF8String target_utf8 = UTF8String.fromString(target); + UTF8String expected_utf8 = UTF8String.fromString(expected); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(expected_utf8, CollationSupport.Lower.exec(target_utf8, collationId)); + } + + @Test + public void testLower() throws SparkException { + // Edge cases + assertLower("", "UTF8_BINARY", ""); + assertLower("", "UTF8_BINARY_LCASE", ""); + assertLower("", "UNICODE", ""); + assertLower("", "UNICODE_CI", ""); + // Basic tests + assertLower("ABCDE", "UTF8_BINARY", "abcde"); + assertLower("ABCDE", "UTF8_BINARY_LCASE", "abcde"); + assertLower("ABCDE", "UNICODE", "abcde"); + assertLower("ABCDE", "UNICODE_CI", "abcde"); + // Uppercase present + 
assertLower("AbCdE", "UTF8_BINARY", "abcde"); + assertLower("aBcDe", "UTF8_BINARY", "abcde"); + assertLower("AbCdE", "UTF8_BINARY_LCASE", "abcde"); + assertLower("aBcDe", "UTF8_BINARY_LCASE", "abcde"); + assertLower("AbCdE", "UNICODE", "abcde"); + assertLower("aBcDe", "UNICODE", "abcde"); + assertLower("AbCdE", "UNICODE_CI", "abcde"); + assertLower("aBcDe", "UNICODE_CI", "abcde"); + // Accent letters + assertLower("AbĆdE","UTF8_BINARY", "abćde"); + assertLower("AbĆdE","UTF8_BINARY_LCASE", "abćde"); + assertLower("AbĆdE","UNICODE", "abćde"); + assertLower("AbĆdE","UNICODE_CI", "abćde"); + // Variable byte length characters + assertLower("aB世De", "UTF8_BINARY", "ab世de"); + assertLower("ÄBĆΔE", "UTF8_BINARY", "äbćδe"); + assertLower("aB世De", "UTF8_BINARY_LCASE", "ab世de"); + assertLower("ÄBĆΔE", "UTF8_BINARY_LCASE", "äbćδe"); + assertLower("aB世De", "UNICODE", "ab世de"); + assertLower("ÄBĆΔE", "UNICODE", "äbćδe"); + assertLower("aB世De", "UNICODE_CI", "ab世de"); + assertLower("ÄBĆΔE", "UNICODE_CI", "äbćδe"); + // Case-variable character length + assertLower("İo", "UTF8_BINARY","i̇o"); + assertLower("İo", "UTF8_BINARY_LCASE","i̇o"); + assertLower("İo", "UNICODE","i̇o"); + assertLower("İo", "UNICODE_CI","i̇o"); + } + + private void assertInitCap(String target, String collationName, String expected) + throws SparkException { + UTF8String target_utf8 = UTF8String.fromString(target); + UTF8String expected_utf8 = UTF8String.fromString(expected); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(expected_utf8, CollationSupport.InitCap.exec(target_utf8, collationId)); + } + + @Test + public void testInitCap() throws SparkException { + // Edge cases + assertInitCap("", "UTF8_BINARY", ""); + assertInitCap("", "UTF8_BINARY_LCASE", ""); + assertInitCap("", "UNICODE", ""); + assertInitCap("", "UNICODE_CI", ""); + // Basic tests + assertInitCap("ABCDE", "UTF8_BINARY", "Abcde"); + assertInitCap("ABCDE", "UTF8_BINARY_LCASE", "Abcde"); + 
assertInitCap("ABCDE", "UNICODE", "Abcde"); + assertInitCap("ABCDE", "UNICODE_CI", "Abcde"); + // Uppercase present + assertInitCap("AbCdE", "UTF8_BINARY", "Abcde"); + assertInitCap("aBcDe", "UTF8_BINARY", "Abcde"); + assertInitCap("AbCdE", "UTF8_BINARY_LCASE", "Abcde"); + assertInitCap("aBcDe", "UTF8_BINARY_LCASE", "Abcde"); + assertInitCap("AbCdE", "UNICODE", "Abcde"); + assertInitCap("aBcDe", "UNICODE", "Abcde"); + assertInitCap("AbCdE", "UNICODE_CI", "Abcde"); + assertInitCap("aBcDe", "UNICODE_CI", "Abcde"); + // Accent letters + assertInitCap("AbĆdE", "UTF8_BINARY", "Abćde"); + assertInitCap("AbĆdE", "UTF8_BINARY_LCASE", "Abćde"); + assertInitCap("AbĆdE", "UNICODE", "Abćde"); + assertInitCap("AbĆdE", "UNICODE_CI", "Abćde"); + // Variable byte length characters + assertInitCap("aB 世 De", "UTF8_BINARY", "Ab 世 De"); + assertInitCap("ÄBĆΔE", "UTF8_BINARY", "Äbćδe"); + assertInitCap("aB 世 De", "UTF8_BINARY_LCASE", "Ab 世 De"); + assertInitCap("ÄBĆΔE", "UTF8_BINARY_LCASE", "Äbćδe"); + assertInitCap("aB 世 De", "UNICODE", "Ab 世 De"); + assertInitCap("ÄBĆΔE", "UNICODE", "Äbćδe"); + assertInitCap("aB 世 de", "UNICODE_CI", "Ab 世 De"); + assertInitCap("ÄBĆΔE", "UNICODE_CI", "Äbćδe"); + // Case-variable character length + assertInitCap("İo", "UTF8_BINARY", "İo"); + assertInitCap("İo", "UTF8_BINARY_LCASE", "İo"); + assertInitCap("İo", "UNICODE", "İo"); + assertInitCap("İo", "UNICODE_CI", "İo"); + } + private void assertStringInstr(String string, String substring, String collationName, Integer expected) throws SparkException { UTF8String str = UTF8String.fromString(string); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index cd21a6f5fdc2..fd4fc7a54229 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -453,14 +453,14 @@ trait String2StringExpression extends ImplicitCastInputTypes { case class Upper(child: Expression) extends UnaryExpression with String2StringExpression with NullIntolerant { - // scalastyle:off caselocale - override def convert(v: UTF8String): UTF8String = v.toUpperCase - // scalastyle:on caselocale + final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId + + override def convert(v: UTF8String): UTF8String = CollationSupport.Upper.exec(v, collationId) final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") + defineCodeGen(ctx, ev, c => CollationSupport.Upper.genCode(c, collationId)) } override protected def withNewChildInternal(newChild: Expression): Upper = copy(child = newChild) @@ -481,14 +481,14 @@ case class Upper(child: Expression) case class Lower(child: Expression) extends UnaryExpression with String2StringExpression with NullIntolerant { - // scalastyle:off caselocale - override def convert(v: UTF8String): UTF8String = v.toLowerCase - // scalastyle:on caselocale + final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId + + override def convert(v: UTF8String): UTF8String = CollationSupport.Lower.exec(v, collationId) final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") + defineCodeGen(ctx, ev, c => CollationSupport.Lower.genCode(c, collationId)) } override def prettyName: String = @@ -1824,16 +1824,16 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + final 
lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) override def dataType: DataType = child.dataType override def nullSafeEval(string: Any): Any = { - // scalastyle:off caselocale - string.asInstanceOf[UTF8String].toLowerCase.toTitleCase - // scalastyle:on caselocale + CollationSupport.InitCap.exec(string.asInstanceOf[UTF8String], collationId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()") + defineCodeGen(ctx, ev, str => CollationSupport.InitCap.genCode(str, collationId)) } override protected def withNewChildInternal(newChild: Expression): InitCap = --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org