This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 07b84dd57e38 [SPARK-47476][SQL] Support REPLACE function to work with collated strings 07b84dd57e38 is described below commit 07b84dd57e38b6396bffaf6f756019e933512d32 Author: Milan Dankovic <milan.danko...@databricks.com> AuthorDate: Fri Apr 26 23:19:22 2024 +0800 [SPARK-47476][SQL] Support REPLACE function to work with collated strings ### What changes were proposed in this pull request? Extend built-in string functions to support non-binary, non-lowercase collations for: replace. ### Why are the changes needed? Update collation support for built-in string functions in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use COLLATE within arguments for the built-in string function REPLACE in Spark SQL queries, using non-binary collations such as UNICODE_CI. ### How was this patch tested? Unit tests for StringReplace (`CollationSupportSuite.java`) and for queries using StringReplace (`CollationStringExpressionsSuite.scala`). ### Was this patch authored or co-authored using generative AI tooling? No ### Algorithm explanation - StringSearch.next() returns the position of the first character of the `search` string in the `source` string. We need to convert this character position into a byte position so that the replace operation is performed correctly. - For the UTF8_BINARY_LCASE collation there is no corresponding collator, so we have to implement custom logic (`lowercaseReplace`). It works by performing matching on **lowercase strings** (`source` & `search`) and using that information to operate on the **original** `source` string. String building is performed in the same way as for other non-binary collations. Similar logic can be found in the existing `int find(UTF8String str, int start)` & `int indexOf(UTF8String v, int start)` methods. Closes #45704 from miland-db/miland-db/string-replace. 
Lead-authored-by: Milan Dankovic <milan.danko...@databricks.com> Co-authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/util/CollationSupport.java | 140 +++++++++++++++++++++ .../org/apache/spark/unsafe/types/UTF8String.java | 4 +- .../spark/unsafe/types/CollationSupportSuite.java | 38 ++++++ .../sql/catalyst/analysis/CollationTypeCasts.scala | 2 +- .../catalyst/expressions/stringExpressions.scala | 16 +-- .../sql/CollationStringExpressionsSuite.scala | 36 ++++++ 6 files changed, 226 insertions(+), 10 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 0c03faa0d23a..0fc37c169612 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -21,6 +21,7 @@ import com.ibm.icu.text.BreakIterator; import com.ibm.icu.text.StringSearch; import com.ibm.icu.util.ULocale; +import org.apache.spark.unsafe.UTF8StringBuilder; import org.apache.spark.unsafe.types.UTF8String; import java.util.ArrayList; @@ -364,6 +365,44 @@ public final class CollationSupport { } } + public static class StringReplace { + public static UTF8String exec(final UTF8String src, final UTF8String search, + final UTF8String replace, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(src, search, replace); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(src, search, replace); + } else { + return execICU(src, search, replace, collationId); + } + } + public static String genCode(final String src, final String search, final String replace, + final int collationId) { + 
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringReplace.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s, %s)", src, search, replace); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s, %s)", src, search, replace); + } else { + return String.format(expr + "ICU(%s, %s, %s, %d)", src, search, replace, collationId); + } + } + public static UTF8String execBinary(final UTF8String src, final UTF8String search, + final UTF8String replace) { + return src.replace(search, replace); + } + public static UTF8String execLowercase(final UTF8String src, final UTF8String search, + final UTF8String replace) { + return CollationAwareUTF8String.lowercaseReplace(src, search, replace); + } + public static UTF8String execICU(final UTF8String src, final UTF8String search, + final UTF8String replace, final int collationId) { + return CollationAwareUTF8String.replace(src, search, replace, collationId); + } + } + // TODO: Add more collation-aware string expressions. /** @@ -401,6 +440,107 @@ public final class CollationSupport { private static class CollationAwareUTF8String { + private static UTF8String replace(final UTF8String src, final UTF8String search, + final UTF8String replace, final int collationId) { + // This collation aware implementation is based on existing implementation on UTF8String + if (src.numBytes() == 0 || search.numBytes() == 0) { + return src; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(src, search, collationId); + + // Find the first occurrence of the search string. + int end = stringSearch.next(); + if (end == StringSearch.DONE) { + // Search string was not found, so string is unchanged. 
+ return src; + } + + // Initialize byte positions + int c = 0; + int byteStart = 0; // position in byte + int byteEnd = 0; // position in byte + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. + int increase = Math.max(0, Math.abs(replace.numBytes() - search.numBytes())) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); + while (end != StringSearch.DONE) { + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); + buf.append(replace); + + // Move byteStart to the beginning of the current match + byteStart = byteEnd; + int cs = c; + // Move cs to the end of the current match + // This is necessary because the search string may contain 'multi-character' characters + while (byteStart < src.numBytes() && cs < c + stringSearch.getMatchLength()) { + byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart)); + cs += 1; + } + // Go to next match + end = stringSearch.next(); + // Update byte positions + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + } + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, + src.numBytes() - byteStart); + return buf.build(); + } + + private static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search, + final UTF8String replace) { + if (src.numBytes() == 0 || search.numBytes() == 0) { + return src; + } + UTF8String lowercaseString = src.toLowerCase(); + UTF8String lowercaseSearch = search.toLowerCase(); + + int start = 0; + int end = lowercaseString.indexOf(lowercaseSearch, 0); + if (end == -1) { + // Search string was not found, so string is unchanged. 
+ return src; + } + + // Initialize byte positions + int c = 0; + int byteStart = 0; // position in byte + int byteEnd = 0; // position in byte + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. + int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); + while (end != -1) { + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); + buf.append(replace); + // Update character positions + start = end + lowercaseSearch.numChars(); + end = lowercaseString.indexOf(lowercaseSearch, start); + // Update byte positions + byteStart = byteEnd + search.numBytes(); + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + } + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, + src.numBytes() - byteStart); + return buf.build(); + } + private static String toUpperCase(final String target, final int collationId) { ULocale locale = CollationFactory.fetchCollation(collationId) .collator.getLocale(ULocale.ACTUAL_LOCALE); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 8ceeddb0c3dd..ca6198df2bbf 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -224,7 +224,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, * Returns the number of bytes for a code point with the first byte as `b` * @param b The first byte of a code point */ - private static int 
numBytesForFirstByte(final byte b) { + public static int numBytesForFirstByte(final byte b) { final int offset = b & 0xFF; byte numBytes = bytesOfCodePointInUTF8[offset]; return (numBytes == 0) ? 1: numBytes; // Skip the first byte disallowed in UTF-8 @@ -382,7 +382,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, /** * Returns the byte at position `i`. */ - private byte getByte(int i) { + public byte getByte(int i) { return Platform.getByte(base, offset + i); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 72edd3e06f9c..6c79fc821317 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -614,6 +614,44 @@ public class CollationSupportSuite { assertFindInSet("İo", "ab,i̇o,12", "UNICODE_CI", 2); } + private void assertReplace(String source, String search, String replace, String collationName, + String expected) throws SparkException { + UTF8String src = UTF8String.fromString(source); + UTF8String sear = UTF8String.fromString(search); + UTF8String repl = UTF8String.fromString(replace); + int collationId = CollationFactory.collationNameToId(collationName); + assertEquals(expected, CollationSupport.StringReplace + .exec(src, sear, repl, collationId).toString()); + } + + @Test + public void testReplace() throws SparkException { + assertReplace("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"); + assertReplace("replace", "pl", "", "UTF8_BINARY", "reace"); + assertReplace("repl世ace", "Pl", "", "UTF8_BINARY", "repl世ace"); + assertReplace("replace", "", "123", "UTF8_BINARY", "replace"); + assertReplace("abcabc", "b", "12", "UTF8_BINARY", "a12ca12c"); + assertReplace("abcdabcd", "bc", "", "UTF8_BINARY", "adad"); + assertReplace("r世eplace", "pl", "xx", 
"UTF8_BINARY_LCASE", "r世exxace"); + assertReplace("repl世ace", "PL", "AB", "UTF8_BINARY_LCASE", "reAB世ace"); + assertReplace("Replace", "", "123", "UTF8_BINARY_LCASE", "Replace"); + assertReplace("re世place", "世", "x", "UTF8_BINARY_LCASE", "rexplace"); + assertReplace("abcaBc", "B", "12", "UTF8_BINARY_LCASE", "a12ca12c"); + assertReplace("AbcdabCd", "Bc", "", "UTF8_BINARY_LCASE", "Adad"); + assertReplace("re世place", "plx", "123", "UNICODE", "re世place"); + assertReplace("世Replace", "re", "", "UNICODE", "世Replace"); + assertReplace("replace世", "", "123", "UNICODE", "replace世"); + assertReplace("aBc世abc", "b", "12", "UNICODE", "aBc世a12c"); + assertReplace("abcdabcd", "bc", "", "UNICODE", "adad"); + assertReplace("replace", "plx", "123", "UNICODE_CI", "replace"); + assertReplace("Replace", "re", "", "UNICODE_CI", "place"); + assertReplace("replace", "", "123", "UNICODE_CI", "replace"); + assertReplace("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"); + assertReplace("a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad"); + assertReplace("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx"); + assertReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"); + } + // TODO: Test more collation-aware string expressions. 
/** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index c7ca5607481d..3ae251e56772 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -64,7 +64,7 @@ object CollationTypeCasts extends TypeCoercionRule { case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | - _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask) => + _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace) => val newChildren = collateToSingleType(otherExpr.children) otherExpr.withNewChildren(newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 612082c56096..135345990e51 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -710,23 +710,25 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId + def this(srcExpr: Expression, searchExpr: Expression) = { this(srcExpr, searchExpr, Literal("")) } override def nullSafeEval(srcEval: Any, searchEval: Any, replaceEval: Any): Any = { - srcEval.asInstanceOf[UTF8String].replace( - searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String]) + 
CollationSupport.StringReplace.exec(srcEval.asInstanceOf[UTF8String], + searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String], collationId); } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (src, search, replace) => { - s"""${ev.value} = $src.replace($search, $replace);""" - }) + defineCodeGen(ctx, ev, (src, search, replace) => + CollationSupport.StringReplace.genCode(src, search, replace, collationId)) } - override def dataType: DataType = StringType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) + override def dataType: DataType = srcExpr.dataType + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) override def first: Expression = srcExpr override def second: Expression = searchExpr override def third: Expression = replaceExpr diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 2b6761475a43..305c51c0b703 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, IntegerType, StringType} @@ -217,6 +218,41 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("Support Replace string expression with collation") { + case class ReplaceTestCase[R](source: String, search: String, replace: String, + c: String, result: R) + val testCases = Seq( + // 
scalastyle:off + ReplaceTestCase("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"), + ReplaceTestCase("repl世ace", "PL", "AB", "UTF8_BINARY_LCASE", "reAB世ace"), + ReplaceTestCase("abcdabcd", "bc", "", "UNICODE", "adad"), + ReplaceTestCase("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"), + ReplaceTestCase("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"), + ReplaceTestCase("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx") + // scalastyle:on + ) + testCases.foreach(t => { + val query = s"SELECT replace(collate('${t.source}','${t.c}'),collate('${t.search}'," + + s"'${t.c}'),collate('${t.replace}','${t.c}'))" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType( + StringType(CollationFactory.collationNameToId(t.c)))) + // Implicit casting + checkAnswer(sql(s"SELECT replace(collate('${t.source}','${t.c}'),'${t.search}'," + + s"'${t.replace}')"), Row(t.result)) + checkAnswer(sql(s"SELECT replace('${t.source}',collate('${t.search}','${t.c}')," + + s"'${t.replace}')"), Row(t.result)) + checkAnswer(sql(s"SELECT replace('${t.source}','${t.search}'," + + s"collate('${t.replace}','${t.c}'))"), Row(t.result)) + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT startswith(collate('abcde', 'UTF8_BINARY_LCASE'),collate('C', 'UNICODE_CI'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } + test("Support EndsWith string expression with collation") { // Supported collations case class EndsWithTestCase[R](l: String, r: String, c: String, result: R) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org