This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 07b84dd57e38 [SPARK-47476][SQL] Support REPLACE function to work with collated strings
07b84dd57e38 is described below

commit 07b84dd57e38b6396bffaf6f756019e933512d32
Author: Milan Dankovic <milan.danko...@databricks.com>
AuthorDate: Fri Apr 26 23:19:22 2024 +0800

    [SPARK-47476][SQL] Support REPLACE function to work with collated strings
    
    ### What changes were proposed in this pull request?
    Extend built-in string functions to support non-binary, non-lowercase collation for: replace.
    
    ### Why are the changes needed?
    Update collation support for built-in string functions in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users should now be able to use COLLATE within arguments for the built-in string function REPLACE in Spark SQL queries, using non-binary collations such as UNICODE_CI.
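    For example (an illustrative sketch, not part of this patch, assuming a running `SparkSession` named `spark`):

    ```java
    // Under the case-insensitive UNICODE_CI collation, REPLACE matches the
    // search string "b" against the uppercase "B" in the source string.
    spark.sql("SELECT replace(collate('aBc', 'UNICODE_CI'), 'b', '12')").show();
    // expected result: a12c
    ```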
    
    ### How was this patch tested?
    Unit tests for queries using StringReplace (`CollationStringExpressionsSuite.scala`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    ### Algorithm explanation
    
    - StringSearch.next() returns the position of the first character of the `search` string within the `source` string. We need to convert this character position to a byte position so that the replace operation can be performed correctly (see the sketch below).
    - For the UTF8_BINARY_LCASE collation there is no corresponding collator, so we have to implement custom logic (`lowercaseReplace`). It performs matching on the **lowercase strings** (`source & search`) and uses that information to operate on the **original** `source` string. String building is performed in the same way as for other non-binary collations.

    Similar logic can be found in the existing `int find(UTF8String str, int start)` & `int indexOf(UTF8String v, int start)` methods.
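    Below is a minimal sketch of the character-to-byte position conversion described above. The helper name `charPosToBytePos` is illustrative only; the actual logic is inlined inside `replace` and `lowercaseReplace` in `CollationSupport.java`.

    ```java
    import org.apache.spark.unsafe.types.UTF8String;

    class PositionConversion {
      // Walk the UTF-8 bytes of `src`, advancing one code point at a time,
      // until `charPos` characters have been skipped; return the byte offset.
      static int charPosToBytePos(UTF8String src, int charPos) {
        int bytePos = 0;
        for (int c = 0; c < charPos && bytePos < src.numBytes(); c++) {
          bytePos += UTF8String.numBytesForFirstByte(src.getByte(bytePos));
        }
        return bytePos;
      }
    }
    ```

    The same loop keeps `byteStart`/`byteEnd` in sync with the character positions returned by `StringSearch.next()` and `indexOf`; `numBytesForFirstByte` and `getByte` are made public on `UTF8String` by this patch to support this conversion.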
    
    Closes #45704 from miland-db/miland-db/string-replace.
    
    Lead-authored-by: Milan Dankovic <milan.danko...@databricks.com>
    Co-authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/util/CollationSupport.java  | 140 +++++++++++++++++++++
 .../org/apache/spark/unsafe/types/UTF8String.java  |   4 +-
 .../spark/unsafe/types/CollationSupportSuite.java  |  38 ++++++
 .../sql/catalyst/analysis/CollationTypeCasts.scala |   2 +-
 .../catalyst/expressions/stringExpressions.scala   |  16 +--
 .../sql/CollationStringExpressionsSuite.scala      |  36 ++++++
 6 files changed, 226 insertions(+), 10 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index 0c03faa0d23a..0fc37c169612 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -21,6 +21,7 @@ import com.ibm.icu.text.BreakIterator;
 import com.ibm.icu.text.StringSearch;
 import com.ibm.icu.util.ULocale;
 
+import org.apache.spark.unsafe.UTF8StringBuilder;
 import org.apache.spark.unsafe.types.UTF8String;
 
 import java.util.ArrayList;
@@ -364,6 +365,44 @@ public final class CollationSupport {
     }
   }
 
+  public static class StringReplace {
+    public static UTF8String exec(final UTF8String src, final UTF8String search,
+        final UTF8String replace, final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality) {
+        return execBinary(src, search, replace);
+      } else if (collation.supportsLowercaseEquality) {
+        return execLowercase(src, search, replace);
+      } else {
+        return execICU(src, search, replace, collationId);
+      }
+    }
+    public static String genCode(final String src, final String search, final String replace,
+        final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.StringReplace.exec";
+      if (collation.supportsBinaryEquality) {
+        return String.format(expr + "Binary(%s, %s, %s)", src, search, replace);
+      } else if (collation.supportsLowercaseEquality) {
+        return String.format(expr + "Lowercase(%s, %s, %s)", src, search, replace);
+      } else {
+        return String.format(expr + "ICU(%s, %s, %s, %d)", src, search, replace, collationId);
+      }
+    }
+    public static UTF8String execBinary(final UTF8String src, final UTF8String search,
+        final UTF8String replace) {
+      return src.replace(search, replace);
+    }
+    public static UTF8String execLowercase(final UTF8String src, final UTF8String search,
+        final UTF8String replace) {
+      return CollationAwareUTF8String.lowercaseReplace(src, search, replace);
+    }
+    public static UTF8String execICU(final UTF8String src, final UTF8String search,
+        final UTF8String replace, final int collationId) {
+      return CollationAwareUTF8String.replace(src, search, replace, collationId);
+    }
+  }
+
   // TODO: Add more collation-aware string expressions.
 
   /**
@@ -401,6 +440,107 @@ public final class CollationSupport {
 
   private static class CollationAwareUTF8String {
 
+    private static UTF8String replace(final UTF8String src, final UTF8String search,
+        final UTF8String replace, final int collationId) {
+      // This collation aware implementation is based on existing implementation on UTF8String
+      if (src.numBytes() == 0 || search.numBytes() == 0) {
+        return src;
+      }
+
+      StringSearch stringSearch = CollationFactory.getStringSearch(src, search, collationId);
+
+      // Find the first occurrence of the search string.
+      int end = stringSearch.next();
+      if (end == StringSearch.DONE) {
+        // Search string was not found, so string is unchanged.
+        return src;
+      }
+
+      // Initialize byte positions
+      int c = 0;
+      int byteStart = 0; // position in byte
+      int byteEnd = 0; // position in byte
+      while (byteEnd < src.numBytes() && c < end) {
+        byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
+        c += 1;
+      }
+
+      // At least one match was found. Estimate space needed for result.
+      // The 16x multiplier here is chosen to match commons-lang3's implementation.
+      int increase = Math.max(0, Math.abs(replace.numBytes() - search.numBytes())) * 16;
+      final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase);
+      while (end != StringSearch.DONE) {
+        buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart);
+        buf.append(replace);
+
+        // Move byteStart to the beginning of the current match
+        byteStart = byteEnd;
+        int cs = c;
+        // Move cs to the end of the current match
+        // This is necessary because the search string may contain 'multi-character' characters
+        while (byteStart < src.numBytes() && cs < c + stringSearch.getMatchLength()) {
+          byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart));
+          cs += 1;
+        }
+        // Go to next match
+        end = stringSearch.next();
+        // Update byte positions
+        while (byteEnd < src.numBytes() && c < end) {
+          byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
+          c += 1;
+        }
+      }
+      buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart,
+        src.numBytes() - byteStart);
+      return buf.build();
+    }
+
+    private static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search,
+        final UTF8String replace) {
+      if (src.numBytes() == 0 || search.numBytes() == 0) {
+        return src;
+      }
+      UTF8String lowercaseString = src.toLowerCase();
+      UTF8String lowercaseSearch = search.toLowerCase();
+
+      int start = 0;
+      int end = lowercaseString.indexOf(lowercaseSearch, 0);
+      if (end == -1) {
+        // Search string was not found, so string is unchanged.
+        return src;
+      }
+
+      // Initialize byte positions
+      int c = 0;
+      int byteStart = 0; // position in byte
+      int byteEnd = 0; // position in byte
+      while (byteEnd < src.numBytes() && c < end) {
+        byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
+        c += 1;
+      }
+
+      // At least one match was found. Estimate space needed for result.
+      // The 16x multiplier here is chosen to match commons-lang3's implementation.
+      int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16;
+      final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase);
+      while (end != -1) {
+        buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart);
+        buf.append(replace);
+        // Update character positions
+        start = end + lowercaseSearch.numChars();
+        end = lowercaseString.indexOf(lowercaseSearch, start);
+        // Update byte positions
+        byteStart = byteEnd + search.numBytes();
+        while (byteEnd < src.numBytes() && c < end) {
+          byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
+          c += 1;
+        }
+      }
+      buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart,
+        src.numBytes() - byteStart);
+      return buf.build();
+    }
+
     private static String toUpperCase(final String target, final int collationId) {
       ULocale locale = CollationFactory.fetchCollation(collationId)
               .collator.getLocale(ULocale.ACTUAL_LOCALE);
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 8ceeddb0c3dd..ca6198df2bbf 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -224,7 +224,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
    * Returns the number of bytes for a code point with the first byte as `b`
    * @param b The first byte of a code point
    */
-  private static int numBytesForFirstByte(final byte b) {
+  public static int numBytesForFirstByte(final byte b) {
     final int offset = b & 0xFF;
     byte numBytes = bytesOfCodePointInUTF8[offset];
     return (numBytes == 0) ? 1: numBytes; // Skip the first byte disallowed in UTF-8
@@ -382,7 +382,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
   /**
    * Returns the byte at position `i`.
    */
-  private byte getByte(int i) {
+  public byte getByte(int i) {
     return Platform.getByte(base, offset + i);
   }
 
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 72edd3e06f9c..6c79fc821317 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -614,6 +614,44 @@ public class CollationSupportSuite {
     assertFindInSet("İo", "ab,i̇o,12", "UNICODE_CI", 2);
   }
 
+  private void assertReplace(String source, String search, String replace, String collationName,
+        String expected) throws SparkException {
+    UTF8String src = UTF8String.fromString(source);
+    UTF8String sear = UTF8String.fromString(search);
+    UTF8String repl = UTF8String.fromString(replace);
+    int collationId = CollationFactory.collationNameToId(collationName);
+    assertEquals(expected, CollationSupport.StringReplace
+      .exec(src, sear, repl, collationId).toString());
+  }
+
+  @Test
+  public void testReplace() throws SparkException {
+    assertReplace("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace");
+    assertReplace("replace", "pl", "", "UTF8_BINARY", "reace");
+    assertReplace("repl世ace", "Pl", "", "UTF8_BINARY", "repl世ace");
+    assertReplace("replace", "", "123", "UTF8_BINARY", "replace");
+    assertReplace("abcabc", "b", "12", "UTF8_BINARY", "a12ca12c");
+    assertReplace("abcdabcd", "bc", "", "UTF8_BINARY", "adad");
+    assertReplace("r世eplace", "pl", "xx", "UTF8_BINARY_LCASE", "r世exxace");
+    assertReplace("repl世ace", "PL", "AB", "UTF8_BINARY_LCASE", "reAB世ace");
+    assertReplace("Replace", "", "123", "UTF8_BINARY_LCASE", "Replace");
+    assertReplace("re世place", "世", "x", "UTF8_BINARY_LCASE", "rexplace");
+    assertReplace("abcaBc", "B", "12", "UTF8_BINARY_LCASE", "a12ca12c");
+    assertReplace("AbcdabCd", "Bc", "", "UTF8_BINARY_LCASE", "Adad");
+    assertReplace("re世place", "plx", "123", "UNICODE", "re世place");
+    assertReplace("世Replace", "re", "", "UNICODE", "世Replace");
+    assertReplace("replace世", "", "123", "UNICODE", "replace世");
+    assertReplace("aBc世abc", "b", "12", "UNICODE", "aBc世a12c");
+    assertReplace("abcdabcd", "bc", "", "UNICODE", "adad");
+    assertReplace("replace", "plx", "123", "UNICODE_CI", "replace");
+    assertReplace("Replace", "re", "", "UNICODE_CI", "place");
+    assertReplace("replace", "", "123", "UNICODE_CI", "replace");
+    assertReplace("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c");
+    assertReplace("a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad");
+    assertReplace("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx");
+    assertReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy");
+  }
+
   // TODO: Test more collation-aware string expressions.
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index c7ca5607481d..3ae251e56772 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -64,7 +64,7 @@ object CollationTypeCasts extends TypeCoercionRule {
 
     case otherExpr @ (
       _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least |
-      _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask) =>
+      _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace) =>
       val newChildren = collateToSingleType(otherExpr.children)
       otherExpr.withNewChildren(newChildren)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 612082c56096..135345990e51 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -710,23 +710,25 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
 case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression)
   extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
+  final lazy val collationId: Int = first.dataType.asInstanceOf[StringType].collationId
+
   def this(srcExpr: Expression, searchExpr: Expression) = {
     this(srcExpr, searchExpr, Literal(""))
   }
 
   override def nullSafeEval(srcEval: Any, searchEval: Any, replaceEval: Any): Any = {
-    srcEval.asInstanceOf[UTF8String].replace(
-      searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String])
+    CollationSupport.StringReplace.exec(srcEval.asInstanceOf[UTF8String],
+      searchEval.asInstanceOf[UTF8String], replaceEval.asInstanceOf[UTF8String], collationId);
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    nullSafeCodeGen(ctx, ev, (src, search, replace) => {
-      s"""${ev.value} = $src.replace($search, $replace);"""
-    })
+    defineCodeGen(ctx, ev, (src, search, replace) =>
+      CollationSupport.StringReplace.genCode(src, search, replace, collationId))
   }
 
-  override def dataType: DataType = StringType
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
+  override def dataType: DataType = srcExpr.dataType
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation)
   override def first: Expression = srcExpr
   override def second: Expression = searchExpr
   override def third: Expression = replaceExpr
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 2b6761475a43..305c51c0b703 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, IntegerType, StringType}
@@ -217,6 +218,41 @@ class CollationStringExpressionsSuite
     assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
   }
 
+  test("Support Replace string expression with collation") {
+    case class ReplaceTestCase[R](source: String, search: String, replace: String,
+        c: String, result: R)
+    val testCases = Seq(
+      // scalastyle:off
+      ReplaceTestCase("r世eplace", "pl", "123", "UTF8_BINARY", "r世e123ace"),
+      ReplaceTestCase("repl世ace", "PL", "AB", "UTF8_BINARY_LCASE", "reAB世ace"),
+      ReplaceTestCase("abcdabcd", "bc", "", "UNICODE", "adad"),
+      ReplaceTestCase("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c"),
+      ReplaceTestCase("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"),
+      ReplaceTestCase("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx")
+      // scalastyle:on
+    )
+    testCases.foreach(t => {
+      val query = s"SELECT 
replace(collate('${t.source}','${t.c}'),collate('${t.search}'," +
+        s"'${t.c}'),collate('${t.replace}','${t.c}'))"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(
+        StringType(CollationFactory.collationNameToId(t.c))))
+      // Implicit casting
+      checkAnswer(sql(s"SELECT 
replace(collate('${t.source}','${t.c}'),'${t.search}'," +
+        s"'${t.replace}')"), Row(t.result))
+      checkAnswer(sql(s"SELECT 
replace('${t.source}',collate('${t.search}','${t.c}')," +
+        s"'${t.replace}')"), Row(t.result))
+      checkAnswer(sql(s"SELECT replace('${t.source}','${t.search}'," +
+        s"collate('${t.replace}','${t.c}'))"), Row(t.result))
+    })
+    // Collation mismatch
+    val collationMismatch = intercept[AnalysisException] {
+      sql("SELECT startswith(collate('abcde', 
'UTF8_BINARY_LCASE'),collate('C', 'UNICODE_CI'))")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
   test("Support EndsWith string expression with collation") {
     // Supported collations
     case class EndsWithTestCase[R](l: String, r: String, c: String, result: R)

