This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 21333f8c1fc0 [SPARK-47409][SQL] Add support for collation for 
StringTrim type of functions/expressions (for UTF8_BINARY & LCASE)
21333f8c1fc0 is described below

commit 21333f8c1fc01756e6708ad6ccf21f585fcb881d
Author: David Milicevic <david.milice...@databricks.com>
AuthorDate: Thu May 9 23:05:20 2024 +0800

    [SPARK-47409][SQL] Add support for collation for StringTrim type of 
functions/expressions (for UTF8_BINARY & LCASE)
    
    Recreates the [original PR](https://github.com/apache/spark/pull/45749) because the code was
    reorganized in [this PR](https://github.com/apache/spark/pull/45978).
    
    ### What changes were proposed in this pull request?
    This PR adds collation support to the StringTrim family of functions/expressions,
    specifically:
    - `StringTrim`
    - `StringTrimBoth`
    - `StringTrimLeft`
    - `StringTrimRight`
    
    Changes:
    - `CollationSupport.java`
      - Add new `StringTrim`, `StringTrimLeft` and `StringTrimRight` classes 
with corresponding logic.
      - `CollationAwareUTF8String` - add new `lowercaseTrim`, `lowercaseTrimLeft` and
        `lowercaseTrimRight` methods that implement the UTF8_LCASE trim logic.
    - `UTF8String.java` - expose some of the methods publicly.
    - `stringExpressions.scala`
      - Change input types.
      - Change eval and code gen logic.
    - `CollationTypeCasts.scala` - add `StringTrim*` expressions to 
`CollationTypeCasts` rules.
    
    ### Why are the changes needed?
    We are incrementally adding collation support to built-in string functions in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes:
    - Users can now use non-default collations in string trim functions.
    
    ### How was this patch tested?
    Existing tests, plus new unit and end-to-end tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46206 from davidm-db/string-trim-functions.
    
    Authored-by: David Milicevic <david.milice...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/util/CollationAwareUTF8String.java    | 470 ++++++++++++++++++
 .../spark/sql/catalyst/util/CollationSupport.java  | 534 ++++++++-------------
 .../org/apache/spark/unsafe/types/UTF8String.java  |   2 +-
 .../spark/unsafe/types/CollationSupportSuite.java  | 193 ++++++++
 .../sql/catalyst/analysis/CollationTypeCasts.scala |   2 +-
 .../catalyst/expressions/stringExpressions.scala   |  53 +-
 .../sql/CollationStringExpressionsSuite.scala      | 161 ++++++-
 7 files changed, 1054 insertions(+), 361 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
new file mode 100644
index 000000000000..ee0d611d7e65
--- /dev/null
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
@@ -0,0 +1,470 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.util;
+
+import com.ibm.icu.lang.UCharacter;
+import com.ibm.icu.text.BreakIterator;
+import com.ibm.icu.text.StringSearch;
+import com.ibm.icu.util.ULocale;
+
+import org.apache.spark.unsafe.UTF8StringBuilder;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
+import static org.apache.spark.unsafe.Platform.copyMemory;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Utility class for collation-aware UTF8String operations.
+ *
+ * <p>Methods prefixed with {@code lowercase} implement UTF8_LCASE semantics by comparing
+ * lowercased text; the remaining search-based methods use ICU {@link StringSearch} instances
+ * obtained from {@link CollationFactory}.
+ */
+public class CollationAwareUTF8String {
+
+  /** Element separator used by {@link #findInSet}. */
+  private static final UTF8String COMMA = UTF8String.fromString(",");
+
+  /**
+   * Replaces every occurrence of {@code search} in {@code src} with {@code replace}, locating
+   * occurrences with an ICU {@link StringSearch} for the given collation.
+   *
+   * @param src the string to search in
+   * @param search the string to look for; when empty, {@code src} is returned unchanged
+   * @param replace the replacement text
+   * @param collationId id of the collation used to match {@code search}
+   * @return {@code src} itself when there is no match, otherwise a new string with every match
+   *     replaced
+   */
+  public static UTF8String replace(final UTF8String src, final UTF8String search,
+      final UTF8String replace, final int collationId) {
+    // This collation aware implementation is based on existing implementation on UTF8String
+    if (src.numBytes() == 0 || search.numBytes() == 0) {
+      return src;
+    }
+
+    StringSearch stringSearch = CollationFactory.getStringSearch(src, search, collationId);
+
+    // Find the first occurrence of the search string.
+    int end = stringSearch.next();
+    if (end == StringSearch.DONE) {
+      // Search string was not found, so string is unchanged.
+      return src;
+    }
+
+    // NOTE(review): StringSearch offsets and match lengths are presumably measured on a Java
+    // String (UTF-16 code units), while `c`/`cs` below count UTF-8 characters; the two differ
+    // for supplementary characters - confirm against CollationFactory.getStringSearch.
+    // Initialize byte positions
+    int c = 0;
+    int byteStart = 0; // position in byte
+    int byteEnd = 0; // position in byte
+    while (byteEnd < src.numBytes() && c < end) {
+      byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
+      c += 1;
+    }
+
+    // At least one match was found. Estimate space needed for result.
+    // The 16x multiplier here is chosen to match commons-lang3's implementation; only a longer
+    // replacement needs extra room.
+    int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16;
+    final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase);
+    while (end != StringSearch.DONE) {
+      buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart);
+      buf.append(replace);
+
+      // Move byteStart to the beginning of the current match
+      byteStart = byteEnd;
+      int cs = c;
+      // Move cs to the end of the current match
+      // This is necessary because the search string may contain 'multi-character' characters
+      while (byteStart < src.numBytes() && cs < c + stringSearch.getMatchLength()) {
+        byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart));
+        cs += 1;
+      }
+      // Go to next match
+      end = stringSearch.next();
+      // Update byte positions
+      while (byteEnd < src.numBytes() && c < end) {
+        byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
+        c += 1;
+      }
+    }
+    buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart,
+      src.numBytes() - byteStart);
+    return buf.build();
+  }
+
+  /**
+   * UTF8_LCASE variant of {@link #replace}: matches are found by searching the lowercased
+   * {@code search} in the lowercased {@code src}, and the corresponding spans of the original
+   * {@code src} are replaced.
+   *
+   * <p>Match positions are mapped back to {@code src} by character count, so a matched
+   * character whose UTF-8 length differs from its lowercase form (e.g. U+212A KELVIN SIGN,
+   * 3 bytes, lowercasing to 'k', 1 byte) is skipped correctly.
+   * NOTE(review): this still assumes lowercasing preserves the number of characters, which
+   * does not hold for e.g. U+0130 - confirm whether such input must be supported.
+   *
+   * @param src the string to search in
+   * @param search the string to look for (case-insensitively); when empty, {@code src} is
+   *     returned unchanged
+   * @param replace the replacement text
+   * @return {@code src} itself when there is no match, otherwise a new string
+   */
+  public static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search,
+      final UTF8String replace) {
+    if (src.numBytes() == 0 || search.numBytes() == 0) {
+      return src;
+    }
+    UTF8String lowercaseString = src.toLowerCase();
+    UTF8String lowercaseSearch = search.toLowerCase();
+
+    // `end` is a character index in lowercaseString.
+    int end = lowercaseString.indexOf(lowercaseSearch, 0);
+    if (end == -1) {
+      // Search string was not found, so string is unchanged.
+      return src;
+    }
+
+    // Number of characters spanned by every match.
+    int matchChars = lowercaseSearch.numChars();
+
+    // Initialize byte positions
+    int c = 0;
+    int byteStart = 0; // position in byte
+    int byteEnd = 0; // position in byte
+    while (byteEnd < src.numBytes() && c < end) {
+      byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
+      c += 1;
+    }
+
+    // At least one match was found. Estimate space needed for result.
+    // The 16x multiplier here is chosen to match commons-lang3's implementation.
+    int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16;
+    final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase);
+    while (end != -1) {
+      buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart);
+      buf.append(replace);
+      // Update character positions
+      int start = end + matchChars;
+      end = lowercaseString.indexOf(lowercaseSearch, start);
+      // Skip the matched characters of src. Walking characters (instead of adding
+      // search.numBytes()) keeps byteStart on a character boundary even when the matched text
+      // in src has a different byte length than `search`.
+      byteStart = byteEnd;
+      for (int i = 0; i < matchChars && byteStart < src.numBytes(); i++) {
+        byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart));
+      }
+      // Update byte positions
+      while (byteEnd < src.numBytes() && c < end) {
+        byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
+        c += 1;
+      }
+    }
+    buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart,
+      src.numBytes() - byteStart);
+    return buf.build();
+  }
+
+  /** Uppercases {@code target} using the locale of the collator for {@code collationId}. */
+  public static String toUpperCase(final String target, final int collationId) {
+    return UCharacter.toUpperCase(collationLocale(collationId), target);
+  }
+
+  /** Lowercases {@code target} using the locale of the collator for {@code collationId}. */
+  public static String toLowerCase(final String target, final int collationId) {
+    return UCharacter.toLowerCase(collationLocale(collationId), target);
+  }
+
+  /**
+   * Title-cases {@code target} using the locale of the collator for {@code collationId}; word
+   * boundaries come from a word {@link BreakIterator} for that same locale.
+   */
+  public static String toTitleCase(final String target, final int collationId) {
+    ULocale locale = collationLocale(collationId);
+    return UCharacter.toTitleCase(locale, target, BreakIterator.getWordInstance(locale));
+  }
+
+  /** Returns the actual ICU locale of the collator registered for {@code collationId}. */
+  private static ULocale collationLocale(final int collationId) {
+    return CollationFactory.fetchCollation(collationId)
+      .collator.getLocale(ULocale.ACTUAL_LOCALE);
+  }
+
+  /**
+   * Returns the 1-based index of the first comma-separated element of {@code set} that matches
+   * {@code match} as a whole element under the given collation.
+   *
+   * @return the element index, or 0 when {@code match} contains a comma or no element matches
+   */
+  public static int findInSet(final UTF8String match, final UTF8String set, int collationId) {
+    if (match.contains(COMMA)) {
+      return 0;
+    }
+
+    String setString = set.toString();
+    StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(),
+      collationId);
+
+    int wordStart = 0;
+    while ((wordStart = stringSearch.next()) != StringSearch.DONE) {
+      int wordEnd = wordStart + stringSearch.getMatchLength();
+      // A match only counts if it spans a whole element, i.e. is delimited by commas or the
+      // ends of the string.
+      boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ',';
+      boolean isValidEnd = wordEnd == setString.length() || setString.charAt(wordEnd) == ',';
+
+      if (isValidStart && isValidEnd) {
+        // The element's 1-based index is one more than the number of commas before it.
+        int pos = 0;
+        for (int i = 0; i < wordStart; i++) {
+          if (setString.charAt(i) == ',') {
+            pos++;
+          }
+        }
+        return pos + 1;
+      }
+    }
+
+    return 0;
+  }
+
+  /**
+   * Returns the position of the first collation-aware match of {@code pattern} in
+   * {@code target}, searching from position {@code start}.
+   *
+   * @return the match position as reported by {@link StringSearch#next()}, i.e.
+   *     {@link StringSearch#DONE} when there is no match; 0 for an empty pattern
+   */
+  public static int indexOf(final UTF8String target, final UTF8String pattern,
+      final int start, final int collationId) {
+    if (pattern.numBytes() == 0) {
+      return 0;
+    }
+
+    StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
+    stringSearch.setIndex(start);
+
+    return stringSearch.next();
+  }
+
+  /**
+   * Byte-offset counterpart of {@link #indexOf}: finds the first collation-aware match of a
+   * non-empty {@code pattern} in {@code target} at or after byte offset {@code start}.
+   *
+   * @return the byte offset of the match in {@code target}, or -1 if there is none
+   */
+  public static int find(UTF8String target, UTF8String pattern, int start,
+      int collationId) {
+    assert (pattern.numBytes() > 0);
+
+    StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
+    // Set search start position (start from character at start position)
+    stringSearch.setIndex(target.bytePosToChar(start));
+
+    // Return either the byte position or -1 if not found
+    int charPos = stringSearch.next();
+    return charPos == StringSearch.DONE ? -1 : target.charPosToByte(charPos);
+  }
+
+  /**
+   * Collation-aware {@code substring_index}: returns the part of {@code string} before the
+   * {@code count}-th occurrence of {@code delimiter} (counting from the left when
+   * {@code count > 0}), or after the {@code -count}-th occurrence counting from the right.
+   *
+   * @return the whole {@code string} if there are fewer than {@code |count|} delimiters; the
+   *     empty string if {@code delimiter} or {@code string} is empty or {@code count == 0}
+   */
+  public static UTF8String subStringIndex(final UTF8String string, final UTF8String delimiter,
+      int count, final int collationId) {
+    if (delimiter.numBytes() == 0 || count == 0 || string.numBytes() == 0) {
+      return UTF8String.EMPTY_UTF8;
+    }
+    if (count > 0) {
+      int idx = -1;
+      while (count > 0) {
+        idx = find(string, delimiter, idx + 1, collationId);
+        if (idx >= 0) {
+          count--;
+        } else {
+          // can not find enough delim
+          return string;
+        }
+      }
+      if (idx == 0) {
+        return UTF8String.EMPTY_UTF8;
+      }
+      byte[] bytes = new byte[idx];
+      copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx);
+      return UTF8String.fromBytes(bytes);
+
+    } else {
+      count = -count;
+
+      StringSearch stringSearch = CollationFactory
+        .getStringSearch(string, delimiter, collationId);
+
+      // NOTE(review): StringSearch positions are mixed with numChars()/substring() character
+      // indices below; these presumably agree only for BMP text - confirm.
+      // Each round rescans from the beginning, keeps the last match starting at or before
+      // `start`, and then moves `start` just in front of that match.
+      int start = string.numChars() - 1;
+      int lastMatchStart = -1;
+      int lastMatchLength = 0;
+      while (count > 0) {
+        stringSearch.reset();
+        int prevStart = -1;
+        int prevLength = 0;
+        int matchStart = stringSearch.next();
+        while (matchStart != StringSearch.DONE && matchStart <= start) {
+          // Remember the length of this particular match: under a collation, different
+          // matches of the same delimiter can have different lengths.
+          prevStart = matchStart;
+          prevLength = stringSearch.getMatchLength();
+          matchStart = stringSearch.next();
+        }
+
+        if (prevStart == -1) {
+          // can not find enough delim
+          return string;
+        }
+        lastMatchStart = prevStart;
+        lastMatchLength = prevLength;
+        start = prevStart - 1;
+        count--;
+      }
+
+      int resultStart = lastMatchStart + lastMatchLength;
+      if (resultStart == string.numChars()) {
+        return UTF8String.EMPTY_UTF8;
+      }
+
+      return string.substring(resultStart, string.numChars());
+    }
+  }
+
+  /**
+   * UTF8_LCASE variant of {@link #subStringIndex}. Delimiters are searched in the lowercased
+   * string, and the found positions are translated back to {@code string} through character
+   * positions, so characters whose lowercase form has a different byte length are not split.
+   * NOTE(review): assumes lowercasing preserves the character count (not true for e.g.
+   * U+0130) - confirm.
+   */
+  public static UTF8String lowercaseSubStringIndex(final UTF8String string,
+      final UTF8String delimiter, int count) {
+    if (delimiter.numBytes() == 0 || count == 0) {
+      return UTF8String.EMPTY_UTF8;
+    }
+
+    UTF8String lowercaseString = string.toLowerCase();
+    UTF8String lowercaseDelimiter = delimiter.toLowerCase();
+
+    if (count > 0) {
+      // `idx` is a byte offset into lowercaseString.
+      int idx = -1;
+      while (count > 0) {
+        idx = lowercaseString.find(lowercaseDelimiter, idx + 1);
+        if (idx >= 0) {
+          count--;
+        } else {
+          // can not find enough delim
+          return string;
+        }
+      }
+      int prefixBytes = lowercaseToOriginalBytePos(string, lowercaseString, idx);
+      if (prefixBytes == 0) {
+        return UTF8String.EMPTY_UTF8;
+      }
+      byte[] bytes = new byte[prefixBytes];
+      copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET,
+        prefixBytes);
+      return UTF8String.fromBytes(bytes);
+
+    } else {
+      // All offsets in this branch refer to lowercaseString until translated below.
+      int lowercaseNumBytes = lowercaseString.numBytes();
+      int delimiterNumBytes = lowercaseDelimiter.numBytes();
+      int idx = lowercaseNumBytes - delimiterNumBytes + 1;
+      count = -count;
+      while (count > 0) {
+        idx = lowercaseString.rfind(lowercaseDelimiter, idx - 1);
+        if (idx >= 0) {
+          count--;
+        } else {
+          // can not find enough delim
+          return string;
+        }
+      }
+      int lowercaseSuffixStart = idx + delimiterNumBytes;
+      if (lowercaseSuffixStart == lowercaseNumBytes) {
+        return UTF8String.EMPTY_UTF8;
+      }
+      int suffixStart = lowercaseToOriginalBytePos(string, lowercaseString, lowercaseSuffixStart);
+      int size = string.numBytes() - suffixStart;
+      byte[] bytes = new byte[size];
+      copyMemory(string.getBaseObject(), string.getBaseOffset() + suffixStart,
+        bytes, BYTE_ARRAY_OFFSET, size);
+      return UTF8String.fromBytes(bytes);
+    }
+  }
+
+  /**
+   * Translates a byte offset in {@code lowercaseString} into the byte offset of the same
+   * character position in {@code string}, clamped to {@code string.numBytes()}.
+   */
+  private static int lowercaseToOriginalBytePos(final UTF8String string,
+      final UTF8String lowercaseString, final int lowercaseBytePos) {
+    int charPos = lowercaseString.bytePosToChar(lowercaseBytePos);
+    return Math.min(string.charPosToByte(charPos), string.numBytes());
+  }
+
+  /**
+   * Rewrites the keys of a translation dictionary so that they are the characters actually
+   * present in {@code string}. For every key of {@code dict}, each collation-aware match in
+   * {@code string} contributes its first code point as a key mapped to {@code dict}'s value;
+   * a match that is collation-equal to an already-added key reuses that key's value.
+   *
+   * <p>NOTE(review): match positions from StringSearch are used as indices into
+   * {@code string.toString()}; this presumes getStringSearch searches that same Java string.
+   */
+  public static Map<String, String> getCollationAwareDict(UTF8String string,
+      Map<String, String> dict, int collationId) {
+    String srcStr = string.toString();
+
+    Map<String, String> collationAwareDict = new HashMap<>();
+    for (String key : dict.keySet()) {
+      StringSearch stringSearch =
+        CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId);
+
+      int pos = 0;
+      while ((pos = stringSearch.next()) != StringSearch.DONE) {
+        int codePoint = srcStr.codePointAt(pos);
+        int charCount = Character.charCount(codePoint);
+        String newKey = srcStr.substring(pos, pos + charCount);
+
+        boolean exists = false;
+        for (String existingKey : collationAwareDict.keySet()) {
+          if (stringSearch.getCollator().compare(existingKey, newKey) == 0) {
+            // Safe despite iterating keySet(): we stop iterating right after this put.
+            collationAwareDict.put(newKey, collationAwareDict.get(existingKey));
+            exists = true;
+            break;
+          }
+        }
+
+        if (!exists) {
+          collationAwareDict.put(newKey, dict.get(key));
+        }
+      }
+    }
+
+    return collationAwareDict;
+  }
+
+  /**
+   * UTF8_LCASE trim: removes, from both ends of {@code srcString}, every character whose
+   * lowercase form equals the lowercase form of some single character of {@code trimString}.
+   *
+   * @param srcString the string to trim
+   * @param trimString the characters to remove; {@code null} yields {@code null}
+   * @return {@code srcString} itself if nothing was trimmed, otherwise the trimmed copy
+   */
+  public static UTF8String lowercaseTrim(
+      final UTF8String srcString,
+      final UTF8String trimString) {
+    // Matching UTF8String behavior for null `trimString`.
+    if (trimString == null) {
+      return null;
+    }
+
+    // Build the lookup set once and share it between both passes.
+    Set<UTF8String> trimChars = lowercaseCharSet(trimString);
+    return trimRightLowercase(trimLeftLowercase(srcString, trimChars), trimChars);
+  }
+
+  /**
+   * UTF8_LCASE left trim; see {@link #lowercaseTrim} for the matching rule.
+   *
+   * @return {@code null} for a {@code null} {@code trimString}
+   */
+  public static UTF8String lowercaseTrimLeft(
+      final UTF8String srcString,
+      final UTF8String trimString) {
+    // Matching UTF8String behavior for null `trimString`.
+    if (trimString == null) {
+      return null;
+    }
+    return trimLeftLowercase(srcString, lowercaseCharSet(trimString));
+  }
+
+  /**
+   * UTF8_LCASE right trim; see {@link #lowercaseTrim} for the matching rule.
+   *
+   * @return {@code null} for a {@code null} {@code trimString}
+   */
+  public static UTF8String lowercaseTrimRight(
+      final UTF8String srcString,
+      final UTF8String trimString) {
+    // Matching UTF8String behavior for null `trimString`.
+    if (trimString == null) {
+      return null;
+    }
+    return trimRightLowercase(srcString, lowercaseCharSet(trimString));
+  }
+
+  /**
+   * Returns the lowercase forms of the individual characters of {@code trimString}. Each
+   * character is lowercased on its own, so a character whose lowercase form spans several code
+   * points stays one entry and can never match just a fragment of that form (which a substring
+   * search in the lowercased trim string would allow).
+   */
+  private static Set<UTF8String> lowercaseCharSet(final UTF8String trimString) {
+    Set<UTF8String> trimChars = new HashSet<>();
+    int numBytes = trimString.numBytes();
+    int byteIdx = 0;
+    while (byteIdx < numBytes) {
+      UTF8String trimChar = charStartingAt(trimString, byteIdx);
+      trimChars.add(trimChar.toLowerCase());
+      byteIdx += trimChar.numBytes();
+    }
+    return trimChars;
+  }
+
+  /**
+   * Returns the character of {@code str} that starts at byte offset {@code byteIdx}, clamped to
+   * the end of {@code str} so a truncated trailing byte sequence is never read past its end.
+   */
+  private static UTF8String charStartingAt(final UTF8String str, final int byteIdx) {
+    int charBytes = UTF8String.numBytesForFirstByte(str.getByte(byteIdx));
+    return str.copyUTF8String(byteIdx, Math.min(byteIdx + charBytes, str.numBytes()) - 1);
+  }
+
+  /** Left-trim pass: drops leading characters whose lowercase form is in {@code trimChars}. */
+  private static UTF8String trimLeftLowercase(
+      final UTF8String srcString,
+      final Set<UTF8String> trimChars) {
+    int numBytes = srcString.numBytes();
+    // The byte position of the first non-matching character in the srcString.
+    int trimByteIdx = 0;
+    while (trimByteIdx < numBytes) {
+      UTF8String searchChar = charStartingAt(srcString, trimByteIdx);
+      if (!trimChars.contains(searchChar.toLowerCase())) {
+        // No matching, exit the search.
+        break;
+      }
+      trimByteIdx += searchChar.numBytes();
+    }
+
+    if (trimByteIdx == 0) {
+      // Nothing trimmed - return original string (not converted to lowercase).
+      return srcString;
+    }
+    if (trimByteIdx >= numBytes) {
+      // Everything trimmed.
+      return UTF8String.EMPTY_UTF8;
+    }
+    return srcString.copyUTF8String(trimByteIdx, numBytes - 1);
+  }
+
+  /** Right-trim pass: drops trailing characters whose lowercase form is in {@code trimChars}. */
+  private static UTF8String trimRightLowercase(
+      final UTF8String srcString,
+      final Set<UTF8String> trimChars) {
+    int numBytes = srcString.numBytes();
+    // Start byte of each character in srcString; characters are walked right to left below.
+    int[] charStarts = new int[numBytes];
+    int numChars = 0;
+    int byteIdx = 0;
+    while (byteIdx < numBytes) {
+      charStarts[numChars++] = byteIdx;
+      byteIdx += UTF8String.numBytesForFirstByte(srcString.getByte(byteIdx));
+    }
+
+    // Exclusive end (in bytes) of the part of srcString that is kept.
+    int trimEnd = numBytes;
+    while (numChars > 0) {
+      int charStart = charStarts[numChars - 1];
+      // Characters are contiguous, so the last kept character spans [charStart, trimEnd).
+      UTF8String searchChar = srcString.copyUTF8String(charStart, trimEnd - 1);
+      if (!trimChars.contains(searchChar.toLowerCase())) {
+        break;
+      }
+      trimEnd = charStart;
+      numChars--;
+    }
+
+    if (trimEnd == numBytes) {
+      // Nothing trimmed.
+      return srcString;
+    }
+    if (trimEnd == 0) {
+      // Everything trimmed.
+      return UTF8String.EMPTY_UTF8;
+    }
+    return srcString.copyUTF8String(0, trimEnd - 1);
+  }
+}
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index b77671cee90b..bea3dc08b448 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -16,23 +16,15 @@
  */
 package org.apache.spark.sql.catalyst.util;
 
-import com.ibm.icu.lang.UCharacter;
-import com.ibm.icu.text.BreakIterator;
 import com.ibm.icu.text.StringSearch;
-import com.ibm.icu.util.ULocale;
 
-import org.apache.spark.unsafe.UTF8StringBuilder;
 import org.apache.spark.unsafe.types.UTF8String;
 
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.regex.Pattern;
 
-import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
-import static org.apache.spark.unsafe.Platform.copyMemory;
-
 /**
  * Static entry point for collation-aware expressions (StringExpressions, 
RegexpExpressions, and
  * other expressions that require custom collation support), as well as 
private utility methods for
@@ -441,7 +433,7 @@ public final class CollationSupport {
       return string.toLowerCase().indexOf(substring.toLowerCase(), start);
     }
+    /** Delegates to the ICU StringSearch-based {@code CollationAwareUTF8String.indexOf}. */
     public static int execICU(final UTF8String string, final UTF8String 
substring, final int start,
-                              final int collationId) {
+        final int collationId) {
       return CollationAwareUTF8String.indexOf(string, substring, start, 
collationId);
     }
   }
@@ -535,6 +527,201 @@ public final class CollationSupport {
     }
   }
 
+  /**
+   * Collation-aware StringTrim: removes trim characters from both ends of a string.
+   *
+   * <p>Collations that support binary equality use the plain {@link UTF8String} trim; every
+   * other collation is routed to the UTF8_LCASE implementation in CollationAwareUTF8String.
+   * NOTE(review): ICU collations would also take the lowercase path here; presumably callers
+   * only allow UTF8_BINARY and UTF8_LCASE for now - confirm.
+   */
+  public static class StringTrim {
+    /** Trims {@code srcString} using the default trim characters under {@code collationId}. */
+    public static UTF8String exec(
+        final UTF8String srcString,
+        final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality) {
+        return execBinary(srcString);
+      } else {
+        return execLowercase(srcString);
+      }
+    }
+    /** Trims the characters of {@code trimString} from both ends, honoring the collation. */
+    public static UTF8String exec(
+        final UTF8String srcString,
+        final UTF8String trimString,
+        final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality) {
+        return execBinary(srcString, trimString);
+      } else {
+        return execLowercase(srcString, trimString);
+      }
+    }
+    /**
+     * Returns the generated-code call for the one-argument trim. The collation is resolved now,
+     * so the emitted call targets the binary or lowercase variant directly.
+     *
+     * @param srcString source text of the argument expression to embed in the call
+     */
+    public static String genCode(
+        final String srcString,
+        final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.StringTrim.exec";
+      if (collation.supportsBinaryEquality) {
+        return String.format(expr + "Binary(%s)", srcString);
+      } else {
+        return String.format(expr + "Lowercase(%s)", srcString);
+      }
+    }
+    /**
+     * Returns the generated-code call for the two-argument trim; see
+     * {@link #genCode(String, int)}.
+     */
+    public static String genCode(
+        final String srcString,
+        final String trimString,
+        final int collationId) {
+      CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.StringTrim.exec";
+      if (collation.supportsBinaryEquality) {
+        return String.format(expr + "Binary(%s, %s)", srcString, trimString);
+      } else {
+        return String.format(expr + "Lowercase(%s, %s)", srcString, trimString);
+      }
+    }
+    /** Binary collations, default trim characters. */
+    public static UTF8String execBinary(
+        final UTF8String srcString) {
+      return srcString.trim();
+    }
+    /** Binary collations, explicit {@code trimString}. */
+    public static UTF8String execBinary(
+        final UTF8String srcString,
+        final UTF8String trimString) {
+      return srcString.trim(trimString);
+    }
+    /**
+     * UTF8_LCASE, default trim characters: reuses the binary trim (presumably because the
+     * default trim characters have no case variants).
+     */
+    public static UTF8String execLowercase(
+        final UTF8String srcString) {
+      return srcString.trim();
+    }
+    /** UTF8_LCASE, explicit {@code trimString}: case-insensitive trim. */
+    public static UTF8String execLowercase(
+        final UTF8String srcString,
+        final UTF8String trimString) {
+      return CollationAwareUTF8String.lowercaseTrim(srcString, trimString);
+    }
+  }
+
+  /**
+   * Collation-aware StringTrimLeft: strips trim characters from the start of a string.
+   * Binary-equality collations use the plain {@link UTF8String} routines; any other collation
+   * is handled by the UTF8_LCASE routines of CollationAwareUTF8String.
+   */
+  public static class StringTrimLeft {
+    /** Binary collations, default trim characters. */
+    public static UTF8String execBinary(final UTF8String srcString) {
+      return srcString.trimLeft();
+    }
+
+    /** Binary collations, explicit {@code trimString}. */
+    public static UTF8String execBinary(final UTF8String srcString, final UTF8String trimString) {
+      return srcString.trimLeft(trimString);
+    }
+
+    /** UTF8_LCASE, default trim characters: identical to the binary left trim. */
+    public static UTF8String execLowercase(final UTF8String srcString) {
+      return srcString.trimLeft();
+    }
+
+    /** UTF8_LCASE, explicit {@code trimString}: case-insensitive left trim. */
+    public static UTF8String execLowercase(
+        final UTF8String srcString, final UTF8String trimString) {
+      return CollationAwareUTF8String.lowercaseTrimLeft(srcString, trimString);
+    }
+
+    /** Left-trims {@code srcString} with the default trim characters under the collation. */
+    public static UTF8String exec(final UTF8String srcString, final int collationId) {
+      final boolean binary = CollationFactory.fetchCollation(collationId).supportsBinaryEquality;
+      return binary ? execBinary(srcString) : execLowercase(srcString);
+    }
+
+    /** Left-trims the characters of {@code trimString}, honoring the collation. */
+    public static UTF8String exec(
+        final UTF8String srcString, final UTF8String trimString, final int collationId) {
+      final boolean binary = CollationFactory.fetchCollation(collationId).supportsBinaryEquality;
+      return binary ? execBinary(srcString, trimString) : execLowercase(srcString, trimString);
+    }
+
+    /** Emits the generated-code call for the one-argument left trim. */
+    public static String genCode(final String srcString, final int collationId) {
+      return String.format("CollationSupport.StringTrimLeft.exec%s(%s)",
+        variantOf(collationId), srcString);
+    }
+
+    /** Emits the generated-code call for the two-argument left trim. */
+    public static String genCode(
+        final String srcString, final String trimString, final int collationId) {
+      return String.format("CollationSupport.StringTrimLeft.exec%s(%s, %s)",
+        variantOf(collationId), srcString, trimString);
+    }
+
+    /** Suffix of the exec method chosen for {@code collationId}: "Binary" or "Lowercase". */
+    private static String variantOf(final int collationId) {
+      return CollationFactory.fetchCollation(collationId).supportsBinaryEquality
+        ? "Binary" : "Lowercase";
+    }
+  }
+
+  /**
+   * Collation-aware StringTrimRight: strips trim characters from the end of a string.
+   * Binary-equality collations use the plain {@link UTF8String} routines; any other collation
+   * is handled by the UTF8_LCASE routines of CollationAwareUTF8String.
+   */
+  public static class StringTrimRight {
+    /** Binary collations, default trim characters. */
+    public static UTF8String execBinary(final UTF8String srcString) {
+      return srcString.trimRight();
+    }
+
+    /** Binary collations, explicit {@code trimString}. */
+    public static UTF8String execBinary(final UTF8String srcString, final UTF8String trimString) {
+      return srcString.trimRight(trimString);
+    }
+
+    /** UTF8_LCASE, default trim characters: identical to the binary right trim. */
+    public static UTF8String execLowercase(final UTF8String srcString) {
+      return srcString.trimRight();
+    }
+
+    /** UTF8_LCASE, explicit {@code trimString}: case-insensitive right trim. */
+    public static UTF8String execLowercase(
+        final UTF8String srcString, final UTF8String trimString) {
+      return CollationAwareUTF8String.lowercaseTrimRight(srcString, trimString);
+    }
+
+    /** Right-trims {@code srcString} with the default trim characters under the collation. */
+    public static UTF8String exec(final UTF8String srcString, final int collationId) {
+      final boolean binary = CollationFactory.fetchCollation(collationId).supportsBinaryEquality;
+      return binary ? execBinary(srcString) : execLowercase(srcString);
+    }
+
+    /** Right-trims the characters of {@code trimString}, honoring the collation. */
+    public static UTF8String exec(
+        final UTF8String srcString, final UTF8String trimString, final int collationId) {
+      final boolean binary = CollationFactory.fetchCollation(collationId).supportsBinaryEquality;
+      return binary ? execBinary(srcString, trimString) : execLowercase(srcString, trimString);
+    }
+
+    /** Emits the generated-code call for the one-argument right trim. */
+    public static String genCode(final String srcString, final int collationId) {
+      return String.format("CollationSupport.StringTrimRight.exec%s(%s)",
+        variantOf(collationId), srcString);
+    }
+
+    /** Emits the generated-code call for the two-argument right trim. */
+    public static String genCode(
+        final String srcString, final String trimString, final int collationId) {
+      return String.format("CollationSupport.StringTrimRight.exec%s(%s, %s)",
+        variantOf(collationId), srcString, trimString);
+    }
+
+    /** Suffix of the exec method chosen for {@code collationId}: "Binary" or "Lowercase". */
+    private static String variantOf(final int collationId) {
+      return CollationFactory.fetchCollation(collationId).supportsBinaryEquality
+        ? "Binary" : "Lowercase";
+    }
+  }
+
   // TODO: Add more collation-aware string expressions.
 
   /**
@@ -566,333 +753,4 @@ public final class CollationSupport {
 
   // TODO: Add other collation-aware expressions.
 
-  /**
-   * Utility class for collation-aware UTF8String operations.
-   */
-
-  private static class CollationAwareUTF8String {
-
-    private static UTF8String replace(final UTF8String src, final UTF8String 
search,
-        final UTF8String replace, final int collationId) {
-      // This collation aware implementation is based on existing 
implementation on UTF8String
-      if (src.numBytes() == 0 || search.numBytes() == 0) {
-        return src;
-      }
-
-      StringSearch stringSearch = CollationFactory.getStringSearch(src, 
search, collationId);
-
-      // Find the first occurrence of the search string.
-      int end = stringSearch.next();
-      if (end == StringSearch.DONE) {
-        // Search string was not found, so string is unchanged.
-        return src;
-      }
-
-      // Initialize byte positions
-      int c = 0;
-      int byteStart = 0; // position in byte
-      int byteEnd = 0; // position in byte
-      while (byteEnd < src.numBytes() && c < end) {
-        byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
-        c += 1;
-      }
-
-      // At least one match was found. Estimate space needed for result.
-      // The 16x multiplier here is chosen to match commons-lang3's 
implementation.
-      int increase = Math.max(0, Math.abs(replace.numBytes() - 
search.numBytes())) * 16;
-      final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + 
increase);
-      while (end != StringSearch.DONE) {
-        buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, 
byteEnd - byteStart);
-        buf.append(replace);
-
-        // Move byteStart to the beginning of the current match
-        byteStart = byteEnd;
-        int cs = c;
-        // Move cs to the end of the current match
-        // This is necessary because the search string may contain 
'multi-character' characters
-        while (byteStart < src.numBytes() && cs < c + 
stringSearch.getMatchLength()) {
-          byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart));
-          cs += 1;
-        }
-        // Go to next match
-        end = stringSearch.next();
-        // Update byte positions
-        while (byteEnd < src.numBytes() && c < end) {
-          byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
-          c += 1;
-        }
-      }
-      buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart,
-        src.numBytes() - byteStart);
-      return buf.build();
-    }
-
-    private static UTF8String lowercaseReplace(final UTF8String src, final 
UTF8String search,
-        final UTF8String replace) {
-      if (src.numBytes() == 0 || search.numBytes() == 0) {
-        return src;
-      }
-      UTF8String lowercaseString = src.toLowerCase();
-      UTF8String lowercaseSearch = search.toLowerCase();
-
-      int start = 0;
-      int end = lowercaseString.indexOf(lowercaseSearch, 0);
-      if (end == -1) {
-        // Search string was not found, so string is unchanged.
-        return src;
-      }
-
-      // Initialize byte positions
-      int c = 0;
-      int byteStart = 0; // position in byte
-      int byteEnd = 0; // position in byte
-      while (byteEnd < src.numBytes() && c < end) {
-        byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
-        c += 1;
-      }
-
-      // At least one match was found. Estimate space needed for result.
-      // The 16x multiplier here is chosen to match commons-lang3's 
implementation.
-      int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16;
-      final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + 
increase);
-      while (end != -1) {
-        buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, 
byteEnd - byteStart);
-        buf.append(replace);
-        // Update character positions
-        start = end + lowercaseSearch.numChars();
-        end = lowercaseString.indexOf(lowercaseSearch, start);
-        // Update byte positions
-        byteStart = byteEnd + search.numBytes();
-        while (byteEnd < src.numBytes() && c < end) {
-          byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
-          c += 1;
-        }
-      }
-      buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart,
-        src.numBytes() - byteStart);
-      return buf.build();
-    }
-
-    private static String toUpperCase(final String target, final int 
collationId) {
-      ULocale locale = CollationFactory.fetchCollation(collationId)
-              .collator.getLocale(ULocale.ACTUAL_LOCALE);
-      return UCharacter.toUpperCase(locale, target);
-    }
-
-    private static String toLowerCase(final String target, final int 
collationId) {
-      ULocale locale = CollationFactory.fetchCollation(collationId)
-              .collator.getLocale(ULocale.ACTUAL_LOCALE);
-      return UCharacter.toLowerCase(locale, target);
-    }
-
-    private static String toTitleCase(final String target, final int 
collationId) {
-      ULocale locale = CollationFactory.fetchCollation(collationId)
-              .collator.getLocale(ULocale.ACTUAL_LOCALE);
-      return UCharacter.toTitleCase(locale, target, 
BreakIterator.getWordInstance(locale));
-    }
-
-    private static int findInSet(final UTF8String match, final UTF8String set, int collationId) {
-      if (match.contains(UTF8String.fromString(","))) {
-        return 0;
-      }
-
-      String setString = set.toString();
-      StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(),
-        collationId);
-
-      int wordStart = 0;
-      while ((wordStart = stringSearch.next()) != StringSearch.DONE) {
-        boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ',';
-        boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length()
-                || setString.charAt(wordStart + stringSearch.getMatchLength()) == ',';
-
-        if (isValidStart && isValidEnd) {
-          int pos = 0;
-          for (int i = 0; i < setString.length() && i < wordStart; i++) {
-            if (setString.charAt(i) == ',') {
-              pos++;
-            }
-          }
-
-          return pos + 1;
-        }
-      }
-
-      return 0;
-    }
-
-    private static int indexOf(final UTF8String target, final UTF8String pattern,
-        final int start, final int collationId) {
-      if (pattern.numBytes() == 0) {
-        return 0;
-      }
-
-      StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
-      stringSearch.setIndex(start);
-
-      return stringSearch.next();
-    }
-
-    private static int find(UTF8String target, UTF8String pattern, int start,
-        int collationId) {
-      assert (pattern.numBytes() > 0);
-
-      StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
-      // Set search start position (start from character at start position)
-      stringSearch.setIndex(target.bytePosToChar(start));
-
-      // Return either the byte position or -1 if not found
-      return target.charPosToByte(stringSearch.next());
-    }
-
-    private static UTF8String subStringIndex(final UTF8String string, final UTF8String delimiter,
-        int count, final int collationId) {
-      if (delimiter.numBytes() == 0 || count == 0 || string.numBytes() == 0) {
-        return UTF8String.EMPTY_UTF8;
-      }
-      if (count > 0) {
-        int idx = -1;
-        while (count > 0) {
-          idx = find(string, delimiter, idx + 1, collationId);
-          if (idx >= 0) {
-            count --;
-          } else {
-            // can not find enough delim
-            return string;
-          }
-        }
-        if (idx == 0) {
-          return UTF8String.EMPTY_UTF8;
-        }
-        byte[] bytes = new byte[idx];
-        copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx);
-        return UTF8String.fromBytes(bytes);
-
-      } else {
-        count = -count;
-
-        StringSearch stringSearch = CollationFactory
-          .getStringSearch(string, delimiter, collationId);
-
-        int start = string.numChars() - 1;
-        int lastMatchLength = 0;
-        int prevStart = -1;
-        while (count > 0) {
-          stringSearch.reset();
-          prevStart = -1;
-          int matchStart = stringSearch.next();
-          lastMatchLength = stringSearch.getMatchLength();
-          while (matchStart <= start) {
-            if (matchStart != StringSearch.DONE) {
-              // Found a match, update the start position
-              prevStart = matchStart;
-              matchStart = stringSearch.next();
-            } else {
-              break;
-            }
-          }
-
-          if (prevStart == -1) {
-            // can not find enough delim
-            return string;
-          } else {
-            start = prevStart - 1;
-            count--;
-          }
-        }
-
-        int resultStart = prevStart + lastMatchLength;
-        if (resultStart == string.numChars()) {
-          return UTF8String.EMPTY_UTF8;
-        }
-
-        return string.substring(resultStart, string.numChars());
-      }
-    }
-
-    private static UTF8String lowercaseSubStringIndex(final UTF8String string,
-      final UTF8String delimiter, int count) {
-      if (delimiter.numBytes() == 0 || count == 0) {
-        return UTF8String.EMPTY_UTF8;
-      }
-
-      UTF8String lowercaseString = string.toLowerCase();
-      UTF8String lowercaseDelimiter = delimiter.toLowerCase();
-
-      if (count > 0) {
-        int idx = -1;
-        while (count > 0) {
-          idx = lowercaseString.find(lowercaseDelimiter, idx + 1);
-          if (idx >= 0) {
-            count --;
-          } else {
-            // can not find enough delim
-            return string;
-          }
-        }
-        if (idx == 0) {
-          return UTF8String.EMPTY_UTF8;
-        }
-        byte[] bytes = new byte[idx];
-        copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx);
-        return UTF8String.fromBytes(bytes);
-
-      } else {
-        int idx = string.numBytes() - delimiter.numBytes() + 1;
-        count = -count;
-        while (count > 0) {
-          idx = lowercaseString.rfind(lowercaseDelimiter, idx - 1);
-          if (idx >= 0) {
-            count --;
-          } else {
-            // can not find enough delim
-            return string;
-          }
-        }
-        if (idx + delimiter.numBytes() == string.numBytes()) {
-          return UTF8String.EMPTY_UTF8;
-        }
-        int size = string.numBytes() - delimiter.numBytes() - idx;
-        byte[] bytes = new byte[size];
-        copyMemory(string.getBaseObject(), string.getBaseOffset() + idx + delimiter.numBytes(),
-                bytes, BYTE_ARRAY_OFFSET, size);
-        return UTF8String.fromBytes(bytes);
-      }
-    }
-
-    private static Map<String, String> getCollationAwareDict(UTF8String string,
-        Map<String, String> dict, int collationId) {
-      String srcStr = string.toString();
-
-      Map<String, String> collationAwareDict = new HashMap<>();
-      for (String key : dict.keySet()) {
-        StringSearch stringSearch =
-          CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId);
-
-        int pos = 0;
-        while ((pos = stringSearch.next()) != StringSearch.DONE) {
-          int codePoint = srcStr.codePointAt(pos);
-          int charCount = Character.charCount(codePoint);
-          String newKey = srcStr.substring(pos, pos + charCount);
-
-          boolean exists = false;
-          for (String existingKey : collationAwareDict.keySet()) {
-            if (stringSearch.getCollator().compare(existingKey, newKey) == 0) {
-              collationAwareDict.put(newKey, collationAwareDict.get(existingKey));
-              exists = true;
-              break;
-            }
-          }
-
-          if (!exists) {
-            collationAwareDict.put(newKey, dict.get(key));
-          }
-        }
-      }
-
-      return collationAwareDict;
-    }
-
-  }
-
 }
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 2a5d14580353..20b26b6ebc5a 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -646,7 +646,7 @@ public final class UTF8String implements 
Comparable<UTF8String>, Externalizable,
    * @param end the end position of the current UTF8String in bytes.
    * @return a new UTF8String in the position of [start, end] of current 
UTF8String bytes.
    */
-  private UTF8String copyUTF8String(int start, int end) {
+  public UTF8String copyUTF8String(int start, int end) {
     int len = end - start + 1;
     byte[] newBytes = new byte[len];
     copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len);
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 2f05b9ad88c9..7fc3c4e349c3 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -800,6 +800,199 @@ public class CollationSupportSuite {
     assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UNICODE_CI", 
"i̇o12i̇oİo");
   }
 
+  private void assertStringTrim(
+      String collation,
+      String sourceString,
+      String trimString,
+      String expectedResultString) throws SparkException {
+    int collationId = CollationFactory.collationNameToId(collation);
+    String result;
+
+    if (trimString == null) {
+      result = CollationSupport.StringTrim.exec(
+        UTF8String.fromString(sourceString), collationId).toString();
+    } else {
+      result = CollationSupport.StringTrim.exec(
+        UTF8String.fromString(sourceString),
+        UTF8String.fromString(trimString),
+        collationId).toString();
+    }
+
+    assertEquals(expectedResultString, result);
+  }
+
+  private void assertStringTrimLeft(
+      String collation,
+      String sourceString,
+      String trimString,
+      String expectedResultString) throws SparkException {
+    int collationId = CollationFactory.collationNameToId(collation);
+    String result;
+
+    if (trimString == null) {
+      result = CollationSupport.StringTrimLeft.exec(
+        UTF8String.fromString(sourceString), collationId).toString();
+    } else {
+      result = CollationSupport.StringTrimLeft.exec(
+        UTF8String.fromString(sourceString),
+        UTF8String.fromString(trimString),
+        collationId).toString();
+    }
+
+    assertEquals(expectedResultString, result);
+  }
+
+  private void assertStringTrimRight(
+      String collation,
+      String sourceString,
+      String trimString,
+      String expectedResultString) throws SparkException {
+    int collationId = CollationFactory.collationNameToId(collation);
+    String result;
+
+    if (trimString == null) {
+      result = CollationSupport.StringTrimRight.exec(
+        UTF8String.fromString(sourceString), collationId).toString();
+    } else {
+      result = CollationSupport.StringTrimRight.exec(
+        UTF8String.fromString(sourceString),
+        UTF8String.fromString(trimString),
+        collationId).toString();
+    }
+
+    assertEquals(expectedResultString, result);
+  }
+
+  @Test
+  public void testStringTrim() throws SparkException {
+    assertStringTrim("UTF8_BINARY", "asd", null, "asd");
+    assertStringTrim("UTF8_BINARY", "  asd  ", null, "asd");
+    assertStringTrim("UTF8_BINARY", " a世a ", null, "a世a");
+    assertStringTrim("UTF8_BINARY", "asd", "x", "asd");
+    assertStringTrim("UTF8_BINARY", "xxasdxx", "x", "asd");
+    assertStringTrim("UTF8_BINARY", "xa世ax", "x", "a世a");
+
+    assertStringTrimLeft("UTF8_BINARY", "asd", null, "asd");
+    assertStringTrimLeft("UTF8_BINARY", "  asd  ", null, "asd  ");
+    assertStringTrimLeft("UTF8_BINARY", " a世a ", null, "a世a ");
+    assertStringTrimLeft("UTF8_BINARY", "asd", "x", "asd");
+    assertStringTrimLeft("UTF8_BINARY", "xxasdxx", "x", "asdxx");
+    assertStringTrimLeft("UTF8_BINARY", "xa世ax", "x", "a世ax");
+
+    assertStringTrimRight("UTF8_BINARY", "asd", null, "asd");
+    assertStringTrimRight("UTF8_BINARY", "  asd  ", null, "  asd");
+    assertStringTrimRight("UTF8_BINARY", " a世a ", null, " a世a");
+    assertStringTrimRight("UTF8_BINARY", "asd", "x", "asd");
+    assertStringTrimRight("UTF8_BINARY", "xxasdxx", "x", "xxasd");
+    assertStringTrimRight("UTF8_BINARY", "xa世ax", "x", "xa世a");
+
+    assertStringTrim("UTF8_BINARY_LCASE", "asd", null, "asd");
+    assertStringTrim("UTF8_BINARY_LCASE", "  asd  ", null, "asd");
+    assertStringTrim("UTF8_BINARY_LCASE", " a世a ", null, "a世a");
+    assertStringTrim("UTF8_BINARY_LCASE", "asd", "x", "asd");
+    assertStringTrim("UTF8_BINARY_LCASE", "xxasdxx", "x", "asd");
+    assertStringTrim("UTF8_BINARY_LCASE", "xa世ax", "x", "a世a");
+
+    assertStringTrimLeft("UTF8_BINARY_LCASE", "asd", null, "asd");
+    assertStringTrimLeft("UTF8_BINARY_LCASE", "  asd  ", null, "asd  ");
+    assertStringTrimLeft("UTF8_BINARY_LCASE", " a世a ", null, "a世a ");
+    assertStringTrimLeft("UTF8_BINARY_LCASE", "asd", "x", "asd");
+    assertStringTrimLeft("UTF8_BINARY_LCASE", "xxasdxx", "x", "asdxx");
+    assertStringTrimLeft("UTF8_BINARY_LCASE", "xa世ax", "x", "a世ax");
+
+    assertStringTrimRight("UTF8_BINARY_LCASE", "asd", null, "asd");
+    assertStringTrimRight("UTF8_BINARY_LCASE", "  asd  ", null, "  asd");
+    assertStringTrimRight("UTF8_BINARY_LCASE", " a世a ", null, " a世a");
+    assertStringTrimRight("UTF8_BINARY_LCASE", "asd", "x", "asd");
+    assertStringTrimRight("UTF8_BINARY_LCASE", "xxasdxx", "x", "xxasd");
+    assertStringTrimRight("UTF8_BINARY_LCASE", "xa世ax", "x", "xa世a");
+
+    assertStringTrim("UNICODE", "asd", null, "asd");
+    assertStringTrim("UNICODE", "  asd  ", null, "asd");
+    assertStringTrim("UNICODE", " a世a ", null, "a世a");
+    assertStringTrim("UNICODE", "asd", "x", "asd");
+    assertStringTrim("UNICODE", "xxasdxx", "x", "asd");
+    assertStringTrim("UNICODE", "xa世ax", "x", "a世a");
+
+    assertStringTrimLeft("UNICODE", "asd", null, "asd");
+    assertStringTrimLeft("UNICODE", "  asd  ", null, "asd  ");
+    assertStringTrimLeft("UNICODE", " a世a ", null, "a世a ");
+    assertStringTrimLeft("UNICODE", "asd", "x", "asd");
+    assertStringTrimLeft("UNICODE", "xxasdxx", "x", "asdxx");
+    assertStringTrimLeft("UNICODE", "xa世ax", "x", "a世ax");
+
+    assertStringTrimRight("UNICODE", "asd", null, "asd");
+    assertStringTrimRight("UNICODE", "  asd  ", null, "  asd");
+    assertStringTrimRight("UNICODE", " a世a ", null, " a世a");
+    assertStringTrimRight("UNICODE", "asd", "x", "asd");
+    assertStringTrimRight("UNICODE", "xxasdxx", "x", "xxasd");
+    assertStringTrimRight("UNICODE", "xa世ax", "x", "xa世a");
+
+    // Test cases where trimString has more than one character
+    assertStringTrim("UTF8_BINARY", "ddsXXXaa", "asd", "XXX");
+    assertStringTrimLeft("UTF8_BINARY", "ddsXXXaa", "asd", "XXXaa");
+    assertStringTrimRight("UTF8_BINARY", "ddsXXXaa", "asd", "ddsXXX");
+
+    assertStringTrim("UTF8_BINARY_LCASE", "ddsXXXaa", "asd", "XXX");
+    assertStringTrimLeft("UTF8_BINARY_LCASE", "ddsXXXaa", "asd", "XXXaa");
+    assertStringTrimRight("UTF8_BINARY_LCASE", "ddsXXXaa", "asd", "ddsXXX");
+
+    assertStringTrim("UNICODE", "ddsXXXaa", "asd", "XXX");
+    assertStringTrimLeft("UNICODE", "ddsXXXaa", "asd", "XXXaa");
+    assertStringTrimRight("UNICODE", "ddsXXXaa", "asd", "ddsXXX");
+
+    // Test cases specific to collation type
+    // uppercase trim, lowercase src
+    assertStringTrim("UTF8_BINARY", "asd", "A", "asd");
+    assertStringTrim("UTF8_BINARY_LCASE", "asd", "A", "sd");
+    assertStringTrim("UNICODE", "asd", "A", "asd");
+    assertStringTrim("UNICODE_CI", "asd", "A", "sd");
+
+    // lowercase trim, uppercase src
+    assertStringTrim("UTF8_BINARY", "ASD", "a", "ASD");
+    assertStringTrim("UTF8_BINARY_LCASE", "ASD", "a", "SD");
+    assertStringTrim("UNICODE", "ASD", "a", "ASD");
+    assertStringTrim("UNICODE_CI", "ASD", "a", "SD");
+
+    // uppercase and lowercase chars of different byte-length (utf8)
+    assertStringTrim("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ");
+    assertStringTrimLeft("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ");
+    assertStringTrimRight("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ");
+
+    assertStringTrim("UTF8_BINARY_LCASE", "ẞaaaẞ", "ß", "aaa");
+    assertStringTrimLeft("UTF8_BINARY_LCASE", "ẞaaaẞ", "ß", "aaaẞ");
+    assertStringTrimRight("UTF8_BINARY_LCASE", "ẞaaaẞ", "ß", "ẞaaa");
+
+    assertStringTrim("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ");
+    assertStringTrimLeft("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ");
+    assertStringTrimRight("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ");
+
+    assertStringTrim("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß");
+    assertStringTrimLeft("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß");
+    assertStringTrimRight("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß");
+
+    assertStringTrim("UTF8_BINARY_LCASE", "ßaaaß", "ẞ", "aaa");
+    assertStringTrimLeft("UTF8_BINARY_LCASE", "ßaaaß", "ẞ", "aaaß");
+    assertStringTrimRight("UTF8_BINARY_LCASE", "ßaaaß", "ẞ", "ßaaa");
+
+    assertStringTrim("UNICODE", "ßaaaß", "ẞ", "ßaaaß");
+    assertStringTrimLeft("UNICODE", "ßaaaß", "ẞ", "ßaaaß");
+    assertStringTrimRight("UNICODE", "ßaaaß", "ẞ", "ßaaaß");
+
+    // different byte-length (utf8) chars trimmed
+    assertStringTrim("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "aaa");
+    assertStringTrimLeft("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "aaaẞ");
+    assertStringTrimRight("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "Ëaaa");
+
+    assertStringTrim("UTF8_BINARY_LCASE", "Ëaaaẞ", "Ëẞ", "aaa");
+    assertStringTrimLeft("UTF8_BINARY_LCASE", "Ëaaaẞ", "Ëẞ", "aaaẞ");
+    assertStringTrimRight("UTF8_BINARY_LCASE", "Ëaaaẞ", "Ëẞ", "Ëaaa");
+
+    assertStringTrim("UNICODE", "Ëaaaẞ", "Ëẞ", "aaa");
+    assertStringTrimLeft("UNICODE", "Ëaaaẞ", "Ëẞ", "aaaẞ");
+    assertStringTrimRight("UNICODE", "Ëaaaẞ", "Ëẞ", "Ëaaa");
+  }
+
   // TODO: Test more collation-aware string expressions.
 
   /**
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
index 44349384187e..a50dad7c8cdb 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala
@@ -74,7 +74,7 @@ object CollationTypeCasts extends TypeCoercionRule {
     case otherExpr @ (
       _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: 
Greatest | _: Least |
       _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: 
StringReplace |
-      _: StringTranslate) =>
+      _: StringTranslate | _: StringTrim | _: StringTrimLeft | _: 
StringTrimRight) =>
       val newChildren = collateToSingleType(otherExpr.children)
       otherExpr.withNewChildren(newChildren)
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 0bdd7930b0bf..09ec501311ad 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -37,7 +37,7 @@ import 
org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO
 import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, 
CollationSupport, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.{AbstractArrayType, 
StringTypeAnyCollation}
+import org.apache.spark.sql.internal.types.{AbstractArrayType, 
StringTypeAnyCollation, StringTypeBinaryLcase}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -1020,8 +1020,10 @@ trait String2TrimExpression extends Expression with 
ImplicitCastInputTypes {
   protected def direction: String
 
   override def children: Seq[Expression] = srcStr +: trimStr.toSeq
-  override def dataType: DataType = StringType
-  override def inputTypes: Seq[AbstractDataType] = 
Seq.fill(children.size)(StringType)
+  override def dataType: DataType = srcStr.dataType
+  override def inputTypes: Seq[AbstractDataType] = 
Seq.fill(children.size)(StringTypeBinaryLcase)
+
+  final lazy val collationId: Int = 
srcStr.dataType.asInstanceOf[StringType].collationId
 
   override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
@@ -1040,13 +1042,19 @@ trait String2TrimExpression extends Expression with 
ImplicitCastInputTypes {
     }
   }
 
-  protected val trimMethod: String
-
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evals = children.map(_.genCode(ctx))
-    val srcString = evals(0)
+    val srcString = evals.head
 
     if (evals.length == 1) {
+      val stringTrimCode: String = this match {
+        case _: StringTrim =>
+          CollationSupport.StringTrim.genCode(srcString.value, collationId)
+        case _: StringTrimLeft =>
+          CollationSupport.StringTrimLeft.genCode(srcString.value, collationId)
+        case _: StringTrimRight =>
+          CollationSupport.StringTrimRight.genCode(srcString.value, 
collationId)
+      }
       ev.copy(code = code"""
          |${srcString.code}
          |boolean ${ev.isNull} = false;
@@ -1054,10 +1062,18 @@ trait String2TrimExpression extends Expression with 
ImplicitCastInputTypes {
          |if (${srcString.isNull}) {
          |  ${ev.isNull} = true;
          |} else {
-         |  ${ev.value} = ${srcString.value}.$trimMethod();
+         |  ${ev.value} = $stringTrimCode;
          |}""".stripMargin)
     } else {
       val trimString = evals(1)
+      val stringTrimCode: String = this match {
+        case _: StringTrim =>
+          CollationSupport.StringTrim.genCode(srcString.value, 
trimString.value, collationId)
+        case _: StringTrimLeft =>
+          CollationSupport.StringTrimLeft.genCode(srcString.value, 
trimString.value, collationId)
+        case _: StringTrimRight =>
+          CollationSupport.StringTrimRight.genCode(srcString.value, 
trimString.value, collationId)
+      }
       ev.copy(code = code"""
          |${srcString.code}
          |boolean ${ev.isNull} = false;
@@ -1069,7 +1085,7 @@ trait String2TrimExpression extends Expression with 
ImplicitCastInputTypes {
          |  if (${trimString.isNull}) {
          |    ${ev.isNull} = true;
          |  } else {
-         |    ${ev.value} = 
${srcString.value}.$trimMethod(${trimString.value});
+         |    ${ev.value} = $stringTrimCode;
          |  }
          |}""".stripMargin)
     }
@@ -1162,12 +1178,11 @@ case class StringTrim(srcStr: Expression, trimStr: 
Option[Expression] = None)
 
   override protected def direction: String = "BOTH"
 
-  override def doEval(srcString: UTF8String): UTF8String = srcString.trim()
+  override def doEval(srcString: UTF8String): UTF8String =
+    CollationSupport.StringTrim.exec(srcString, collationId)
 
   override def doEval(srcString: UTF8String, trimString: UTF8String): 
UTF8String =
-    srcString.trim(trimString)
-
-  override val trimMethod: String = "trim"
+    CollationSupport.StringTrim.exec(srcString, trimString, collationId)
 
   override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[Expression]): Expression =
     copy(
@@ -1270,12 +1285,11 @@ case class StringTrimLeft(srcStr: Expression, trimStr: 
Option[Expression] = None
 
   override protected def direction: String = "LEADING"
 
-  override def doEval(srcString: UTF8String): UTF8String = srcString.trimLeft()
+  override def doEval(srcString: UTF8String): UTF8String =
+    CollationSupport.StringTrimLeft.exec(srcString, collationId)
 
   override def doEval(srcString: UTF8String, trimString: UTF8String): 
UTF8String =
-    srcString.trimLeft(trimString)
-
-  override val trimMethod: String = "trimLeft"
+    CollationSupport.StringTrimLeft.exec(srcString, trimString, collationId)
 
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): StringTrimLeft =
@@ -1331,12 +1345,11 @@ case class StringTrimRight(srcStr: Expression, trimStr: 
Option[Expression] = Non
 
   override protected def direction: String = "TRAILING"
 
-  override def doEval(srcString: UTF8String): UTF8String = 
srcString.trimRight()
+  override def doEval(srcString: UTF8String): UTF8String =
+    CollationSupport.StringTrimRight.exec(srcString, collationId)
 
   override def doEval(srcString: UTF8String, trimString: UTF8String): 
UTF8String =
-    srcString.trimRight(trimString)
-
-  override val trimMethod: String = "trimRight"
+    CollationSupport.StringTrimRight.exec(srcString, trimString, collationId)
 
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): StringTrimRight =
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index b9a4fecd0465..9cc123b708af 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, 
Literal, StringTrim, StringTrimLeft, StringTrimRight}
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -26,7 +27,8 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, 
BooleanType, DataType,
 // scalastyle:off nonascii
 class CollationStringExpressionsSuite
   extends QueryTest
-  with SharedSparkSession {
+  with SharedSparkSession
+  with ExpressionEvalHelper {
 
   test("Support ConcatWs string expression with collation") {
     // Supported collations
@@ -800,6 +802,163 @@ class CollationStringExpressionsSuite
     assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
   }
 
+  test("StringTrim* functions - unit tests for both paths (codegen and eval)") {
+    // Without trimString param.
+    checkEvaluation(StringTrim(Literal.create( "  asd  ", StringType("UTF8_BINARY"))), "asd")
+    checkEvaluation(
+      StringTrimLeft(Literal.create("  asd  ", StringType("UTF8_BINARY_LCASE"))), "asd  ")
+    checkEvaluation(StringTrimRight(Literal.create("  asd  ", StringType("UNICODE"))), "  asd")
+
+    // With trimString param.
+    checkEvaluation(
+      StringTrim(
+        Literal.create("  asd  ", StringType("UTF8_BINARY")),
+        Literal.create(" ", StringType("UTF8_BINARY"))),
+      "asd")
+    checkEvaluation(
+      StringTrimLeft(
+        Literal.create("  asd  ", StringType("UTF8_BINARY_LCASE")),
+        Literal.create(" ", StringType("UTF8_BINARY_LCASE"))),
+      "asd  ")
+    checkEvaluation(
+      StringTrimRight(
+        Literal.create("  asd  ", StringType("UNICODE")),
+        Literal.create(" ", StringType("UNICODE"))),
+      "  asd")
+
+    checkEvaluation(
+      StringTrim(
+        Literal.create("xxasdxx", StringType("UTF8_BINARY")),
+        Literal.create("x", StringType("UTF8_BINARY"))),
+      "asd")
+    checkEvaluation(
+      StringTrimLeft(
+        Literal.create("xxasdxx", StringType("UTF8_BINARY_LCASE")),
+        Literal.create("x", StringType("UTF8_BINARY_LCASE"))),
+      "asdxx")
+    checkEvaluation(
+      StringTrimRight(
+        Literal.create("xxasdxx", StringType("UNICODE")),
+        Literal.create("x", StringType("UNICODE"))),
+      "xxasd")
+  }
+
+  test("StringTrim* functions - E2E tests") {
+    case class StringTrimTestCase(
+      collation: String,
+      trimFunc: String,
+      sourceString: String,
+      hasTrimString: Boolean,
+      trimString: String,
+      expectedResultString: String)
+
+    val testCases = Seq(
+      StringTrimTestCase("UTF8_BINARY", "TRIM", "  asd  ", false, null, "asd"),
+      StringTrimTestCase("UTF8_BINARY", "BTRIM", "  asd  ", true, null, null),
+      StringTrimTestCase("UTF8_BINARY", "LTRIM", "xxasdxx", true, "x", "asdxx"),
+      StringTrimTestCase("UTF8_BINARY", "RTRIM", "xxasdxx", true, "x", "xxasd"),
+
+      StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", "  asd  ", true, null, null),
+      StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "xxasdxx", true, "x", "asd"),
+      StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "xxasdxx", true, "x", "asdxx"),
+      StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", "  asd  ", false, null, "  asd"),
+
+      StringTrimTestCase("UNICODE", "TRIM", "xxasdxx", true, "x", "asd"),
+      StringTrimTestCase("UNICODE", "BTRIM", "xxasdxx", true, "x", "asd"),
+      StringTrimTestCase("UNICODE", "LTRIM", "  asd  ", false, null, "asd  "),
+      StringTrimTestCase("UNICODE", "RTRIM", "  asd  ", true, null, null)
+
+      // Other more complex cases can be found in unit tests in CollationSupportSuite.java.
+    )
+
+    testCases.foreach(testCase => {
+      var df: DataFrame = null
+
+      if (testCase.trimFunc.equalsIgnoreCase("BTRIM")) {
+        // BTRIM has arguments in (srcStr, trimStr) order
+        df = sql(s"SELECT ${testCase.trimFunc}(" +
+          s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" +
+            (if (!testCase.hasTrimString) ""
+            else if (testCase.trimString == null) ", null"
+            else s", '${testCase.trimString}'") +
+          ")")
+      }
+      else {
+        // While other functions have arguments in (trimStr, srcStr) order
+        df = sql(s"SELECT ${testCase.trimFunc}(" +
+            (if (!testCase.hasTrimString) ""
+            else if (testCase.trimString == null) "null, "
+            else s"'${testCase.trimString}', ") +
+          s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" +
+          ")")
+      }
+
+      checkAnswer(df = df, expectedAnswer = Row(testCase.expectedResultString))
+    })
+  }
+
+  test("StringTrim* functions - implicit collations") {
+    checkAnswer(
+      df = sql("SELECT TRIM(COLLATE('x', 'UTF8_BINARY'), COLLATE('xax', 'UTF8_BINARY'))"),
+      expectedAnswer = Row("a"))
+    checkAnswer(
+      df = sql("SELECT BTRIM(COLLATE('xax', 'UTF8_BINARY_LCASE'), "
+        + "COLLATE('x', 'UTF8_BINARY_LCASE'))"),
+      expectedAnswer = Row("a"))
+    checkAnswer(
+      df = sql("SELECT LTRIM(COLLATE('x', 'UNICODE'), COLLATE('xax', 'UNICODE'))"),
+      expectedAnswer = Row("ax"))
+
+    checkAnswer(
+      df = sql("SELECT RTRIM('x', COLLATE('xax', 'UTF8_BINARY'))"),
+      expectedAnswer = Row("xa"))
+    checkAnswer(
+      df = sql("SELECT TRIM('x', COLLATE('xax', 'UTF8_BINARY_LCASE'))"),
+      expectedAnswer = Row("a"))
+    checkAnswer(
+      df = sql("SELECT BTRIM('xax', COLLATE('x', 'UNICODE'))"),
+      expectedAnswer = Row("a"))
+
+    checkAnswer(
+      df = sql("SELECT LTRIM(COLLATE('x', 'UTF8_BINARY'), 'xax')"),
+      expectedAnswer = Row("ax"))
+    checkAnswer(
+      df = sql("SELECT RTRIM(COLLATE('x', 'UTF8_BINARY_LCASE'), 'xax')"),
+      expectedAnswer = Row("xa"))
+    checkAnswer(
+      df = sql("SELECT TRIM(COLLATE('x', 'UNICODE'), 'xax')"),
+      expectedAnswer = Row("a"))
+  }
+
+  test("StringTrim* functions - collation type mismatch") {
+    List("TRIM", "LTRIM", "RTRIM").foreach(func => {
+      val collationMismatch = intercept[AnalysisException] {
+        sql("SELECT " + func + "(COLLATE('x', 'UTF8_BINARY_LCASE'), "
+          + "COLLATE('xxaaaxx', 'UNICODE'))")
+      }
+      assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+    })
+
+    val collationMismatch = intercept[AnalysisException] {
+      sql("SELECT BTRIM(COLLATE('xxaaaxx', 'UNICODE'), COLLATE('x', 'UTF8_BINARY_LCASE'))")
+    }
+    assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+  }
+
+  test("StringTrim* functions - unsupported collation types") {
+    List("TRIM", "LTRIM", "RTRIM").foreach(func => {
+      val collationMismatch = intercept[AnalysisException] {
+        sql("SELECT " + func + "(COLLATE('x', 'UNICODE_CI'), COLLATE('xxaaaxx', 'UNICODE_CI'))")
+      }
+      assert(collationMismatch.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+    })
+
+    val collationMismatch = intercept[AnalysisException] {
+      sql("SELECT BTRIM(COLLATE('xxaaaxx', 'UNICODE_CI'), COLLATE('x', 'UNICODE_CI'))")
+    }
+    assert(collationMismatch.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+  }
+
   // TODO: Add more tests for other string expressions
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to