This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 21333f8c1fc0 [SPARK-47409][SQL] Add support for collation for StringTrim type of functions/expressions (for UTF8_BINARY & LCASE) 21333f8c1fc0 is described below commit 21333f8c1fc01756e6708ad6ccf21f585fcb881d Author: David Milicevic <david.milice...@databricks.com> AuthorDate: Thu May 9 23:05:20 2024 +0800 [SPARK-47409][SQL] Add support for collation for StringTrim type of functions/expressions (for UTF8_BINARY & LCASE) Recreating [original PR](https://github.com/apache/spark/pull/45749) because code has been reorganized in [this PR](https://github.com/apache/spark/pull/45978). ### What changes were proposed in this pull request? This PR is created to add support for collations to the StringTrim family of functions/expressions, specifically: - `StringTrim` - `StringTrimBoth` - `StringTrimLeft` - `StringTrimRight` Changes: - `CollationSupport.java` - Add new `StringTrim`, `StringTrimLeft` and `StringTrimRight` classes with corresponding logic. - `CollationAwareUTF8String` - add new `lowercaseTrim`, `lowercaseTrimLeft` and `lowercaseTrimRight` methods that implement the trim logic for the lowercase collation. - `UTF8String.java` - expose some of the methods publicly. - `stringExpressions.scala` - Change input types. - Change eval and code gen logic. - `CollationTypeCasts.scala` - add `StringTrim*` expressions to `CollationTypeCasts` rules. ### Why are the changes needed? We are incrementally adding collation support to built-in string functions in Spark. ### Does this PR introduce _any_ user-facing change? Yes: - Users should now be able to use non-default collations in string trim functions. ### How was this patch tested? Already existing tests + new unit/e2e tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46206 from davidm-db/string-trim-functions. 
Authored-by: David Milicevic <david.milice...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../catalyst/util/CollationAwareUTF8String.java | 470 ++++++++++++++++++ .../spark/sql/catalyst/util/CollationSupport.java | 534 ++++++++------------- .../org/apache/spark/unsafe/types/UTF8String.java | 2 +- .../spark/unsafe/types/CollationSupportSuite.java | 193 ++++++++ .../sql/catalyst/analysis/CollationTypeCasts.scala | 2 +- .../catalyst/expressions/stringExpressions.scala | 53 +- .../sql/CollationStringExpressionsSuite.scala | 161 ++++++- 7 files changed, 1054 insertions(+), 361 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java new file mode 100644 index 000000000000..ee0d611d7e65 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -0,0 +1,470 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.util; + +import com.ibm.icu.lang.UCharacter; +import com.ibm.icu.text.BreakIterator; +import com.ibm.icu.text.StringSearch; +import com.ibm.icu.util.ULocale; + +import org.apache.spark.unsafe.UTF8StringBuilder; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; +import static org.apache.spark.unsafe.Platform.copyMemory; + +import java.util.HashMap; +import java.util.Map; + +/** + * Utility class for collation-aware UTF8String operations. + */ +public class CollationAwareUTF8String { + public static UTF8String replace(final UTF8String src, final UTF8String search, + final UTF8String replace, final int collationId) { + // This collation aware implementation is based on existing implementation on UTF8String + if (src.numBytes() == 0 || search.numBytes() == 0) { + return src; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(src, search, collationId); + + // Find the first occurrence of the search string. + int end = stringSearch.next(); + if (end == StringSearch.DONE) { + // Search string was not found, so string is unchanged. + return src; + } + + // Initialize byte positions + int c = 0; + int byteStart = 0; // position in byte + int byteEnd = 0; // position in byte + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. 
+ int increase = Math.max(0, Math.abs(replace.numBytes() - search.numBytes())) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); + while (end != StringSearch.DONE) { + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); + buf.append(replace); + + // Move byteStart to the beginning of the current match + byteStart = byteEnd; + int cs = c; + // Move cs to the end of the current match + // This is necessary because the search string may contain 'multi-character' characters + while (byteStart < src.numBytes() && cs < c + stringSearch.getMatchLength()) { + byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart)); + cs += 1; + } + // Go to next match + end = stringSearch.next(); + // Update byte positions + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + } + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, + src.numBytes() - byteStart); + return buf.build(); + } + + public static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search, + final UTF8String replace) { + if (src.numBytes() == 0 || search.numBytes() == 0) { + return src; + } + UTF8String lowercaseString = src.toLowerCase(); + UTF8String lowercaseSearch = search.toLowerCase(); + + int start = 0; + int end = lowercaseString.indexOf(lowercaseSearch, 0); + if (end == -1) { + // Search string was not found, so string is unchanged. + return src; + } + + // Initialize byte positions + int c = 0; + int byteStart = 0; // position in byte + int byteEnd = 0; // position in byte + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + + // At least one match was found. Estimate space needed for result. + // The 16x multiplier here is chosen to match commons-lang3's implementation. 
+ int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16; + final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); + while (end != -1) { + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); + buf.append(replace); + // Update character positions + start = end + lowercaseSearch.numChars(); + end = lowercaseString.indexOf(lowercaseSearch, start); + // Update byte positions + byteStart = byteEnd + search.numBytes(); + while (byteEnd < src.numBytes() && c < end) { + byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); + c += 1; + } + } + buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, + src.numBytes() - byteStart); + return buf.build(); + } + + public static String toUpperCase(final String target, final int collationId) { + ULocale locale = CollationFactory.fetchCollation(collationId) + .collator.getLocale(ULocale.ACTUAL_LOCALE); + return UCharacter.toUpperCase(locale, target); + } + + public static String toLowerCase(final String target, final int collationId) { + ULocale locale = CollationFactory.fetchCollation(collationId) + .collator.getLocale(ULocale.ACTUAL_LOCALE); + return UCharacter.toLowerCase(locale, target); + } + + public static String toTitleCase(final String target, final int collationId) { + ULocale locale = CollationFactory.fetchCollation(collationId) + .collator.getLocale(ULocale.ACTUAL_LOCALE); + return UCharacter.toTitleCase(locale, target, BreakIterator.getWordInstance(locale)); + } + + public static int findInSet(final UTF8String match, final UTF8String set, int collationId) { + if (match.contains(UTF8String.fromString(","))) { + return 0; + } + + String setString = set.toString(); + StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(), + collationId); + + int wordStart = 0; + while ((wordStart = stringSearch.next()) != StringSearch.DONE) { + boolean isValidStart = wordStart == 0 || 
setString.charAt(wordStart - 1) == ','; + boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() + || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; + + if (isValidStart && isValidEnd) { + int pos = 0; + for (int i = 0; i < setString.length() && i < wordStart; i++) { + if (setString.charAt(i) == ',') { + pos++; + } + } + + return pos + 1; + } + } + + return 0; + } + + public static int indexOf(final UTF8String target, final UTF8String pattern, + final int start, final int collationId) { + if (pattern.numBytes() == 0) { + return 0; + } + + StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); + stringSearch.setIndex(start); + + return stringSearch.next(); + } + + public static int find(UTF8String target, UTF8String pattern, int start, + int collationId) { + assert (pattern.numBytes() > 0); + + StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); + // Set search start position (start from character at start position) + stringSearch.setIndex(target.bytePosToChar(start)); + + // Return either the byte position or -1 if not found + return target.charPosToByte(stringSearch.next()); + } + + public static UTF8String subStringIndex(final UTF8String string, final UTF8String delimiter, + int count, final int collationId) { + if (delimiter.numBytes() == 0 || count == 0 || string.numBytes() == 0) { + return UTF8String.EMPTY_UTF8; + } + if (count > 0) { + int idx = -1; + while (count > 0) { + idx = find(string, delimiter, idx + 1, collationId); + if (idx >= 0) { + count --; + } else { + // can not find enough delim + return string; + } + } + if (idx == 0) { + return UTF8String.EMPTY_UTF8; + } + byte[] bytes = new byte[idx]; + copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx); + return UTF8String.fromBytes(bytes); + + } else { + count = -count; + + StringSearch stringSearch = CollationFactory + 
.getStringSearch(string, delimiter, collationId); + + int start = string.numChars() - 1; + int lastMatchLength = 0; + int prevStart = -1; + while (count > 0) { + stringSearch.reset(); + prevStart = -1; + int matchStart = stringSearch.next(); + lastMatchLength = stringSearch.getMatchLength(); + while (matchStart <= start) { + if (matchStart != StringSearch.DONE) { + // Found a match, update the start position + prevStart = matchStart; + matchStart = stringSearch.next(); + } else { + break; + } + } + + if (prevStart == -1) { + // can not find enough delim + return string; + } else { + start = prevStart - 1; + count--; + } + } + + int resultStart = prevStart + lastMatchLength; + if (resultStart == string.numChars()) { + return UTF8String.EMPTY_UTF8; + } + + return string.substring(resultStart, string.numChars()); + } + } + + public static UTF8String lowercaseSubStringIndex(final UTF8String string, + final UTF8String delimiter, int count) { + if (delimiter.numBytes() == 0 || count == 0) { + return UTF8String.EMPTY_UTF8; + } + + UTF8String lowercaseString = string.toLowerCase(); + UTF8String lowercaseDelimiter = delimiter.toLowerCase(); + + if (count > 0) { + int idx = -1; + while (count > 0) { + idx = lowercaseString.find(lowercaseDelimiter, idx + 1); + if (idx >= 0) { + count--; + } else { + // can not find enough delim + return string; + } + } + if (idx == 0) { + return UTF8String.EMPTY_UTF8; + } + byte[] bytes = new byte[idx]; + copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx); + return UTF8String.fromBytes(bytes); + + } else { + int idx = string.numBytes() - delimiter.numBytes() + 1; + count = -count; + while (count > 0) { + idx = lowercaseString.rfind(lowercaseDelimiter, idx - 1); + if (idx >= 0) { + count--; + } else { + // can not find enough delim + return string; + } + } + if (idx + delimiter.numBytes() == string.numBytes()) { + return UTF8String.EMPTY_UTF8; + } + int size = string.numBytes() - delimiter.numBytes() - 
idx; + byte[] bytes = new byte[size]; + copyMemory(string.getBaseObject(), string.getBaseOffset() + idx + delimiter.numBytes(), + bytes, BYTE_ARRAY_OFFSET, size); + return UTF8String.fromBytes(bytes); + } + } + + public static Map<String, String> getCollationAwareDict(UTF8String string, + Map<String, String> dict, int collationId) { + String srcStr = string.toString(); + + Map<String, String> collationAwareDict = new HashMap<>(); + for (String key : dict.keySet()) { + StringSearch stringSearch = + CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId); + + int pos = 0; + while ((pos = stringSearch.next()) != StringSearch.DONE) { + int codePoint = srcStr.codePointAt(pos); + int charCount = Character.charCount(codePoint); + String newKey = srcStr.substring(pos, pos + charCount); + + boolean exists = false; + for (String existingKey : collationAwareDict.keySet()) { + if (stringSearch.getCollator().compare(existingKey, newKey) == 0) { + collationAwareDict.put(newKey, collationAwareDict.get(existingKey)); + exists = true; + break; + } + } + + if (!exists) { + collationAwareDict.put(newKey, dict.get(key)); + } + } + } + + return collationAwareDict; + } + + public static UTF8String lowercaseTrim( + final UTF8String srcString, + final UTF8String trimString) { + // Matching UTF8String behavior for null `trimString`. + if (trimString == null) { + return null; + } + + UTF8String leftTrimmed = lowercaseTrimLeft(srcString, trimString); + return lowercaseTrimRight(leftTrimmed, trimString); + } + + public static UTF8String lowercaseTrimLeft( + final UTF8String srcString, + final UTF8String trimString) { + // Matching UTF8String behavior for null `trimString`. + if (trimString == null) { + return null; + } + + // The searching byte position in the srcString. + int searchIdx = 0; + // The byte position of a first non-matching character in the srcString. + int trimByteIdx = 0; + // Number of bytes in srcString. 
+ int numBytes = srcString.numBytes(); + // Convert trimString to lowercase, so it can be searched properly. + UTF8String lowercaseTrimString = trimString.toLowerCase(); + + while (searchIdx < numBytes) { + UTF8String searchChar = srcString.copyUTF8String( + searchIdx, + searchIdx + UTF8String.numBytesForFirstByte(srcString.getByte(searchIdx)) - 1); + int searchCharBytes = searchChar.numBytes(); + + // Try to find the matching for the searchChar in the trimString. + if (lowercaseTrimString.find(searchChar.toLowerCase(), 0) >= 0) { + trimByteIdx += searchCharBytes; + searchIdx += searchCharBytes; + } else { + // No matching, exit the search. + break; + } + } + + if (searchIdx == 0) { + // Nothing trimmed - return original string (not converted to lowercase). + return srcString; + } + if (trimByteIdx >= numBytes) { + // Everything trimmed. + return UTF8String.EMPTY_UTF8; + } + return srcString.copyUTF8String(trimByteIdx, numBytes - 1); + } + + public static UTF8String lowercaseTrimRight( + final UTF8String srcString, + final UTF8String trimString) { + // Matching UTF8String behavior for null `trimString`. + if (trimString == null) { + return null; + } + + // Number of bytes iterated from the srcString. + int byteIdx = 0; + // Number of characters iterated from the srcString. + int numChars = 0; + // Number of bytes in srcString. + int numBytes = srcString.numBytes(); + // Array of character length for the srcString. + int[] stringCharLen = new int[numBytes]; + // Array of the first byte position for each character in the srcString. + int[] stringCharPos = new int[numBytes]; + // Convert trimString to lowercase, so it can be searched properly. + UTF8String lowercaseTrimString = trimString.toLowerCase(); + + // Build the position and length array. 
+ while (byteIdx < numBytes) { + stringCharPos[numChars] = byteIdx; + stringCharLen[numChars] = UTF8String.numBytesForFirstByte(srcString.getByte(byteIdx)); + byteIdx += stringCharLen[numChars]; + numChars++; + } + + // Index trimEnd points to the first no matching byte position from the right side of + // the source string. + int trimByteIdx = numBytes - 1; + + while (numChars > 0) { + UTF8String searchChar = srcString.copyUTF8String( + stringCharPos[numChars - 1], + stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); + + if(lowercaseTrimString.find(searchChar.toLowerCase(), 0) >= 0) { + trimByteIdx -= stringCharLen[numChars - 1]; + numChars--; + } else { + break; + } + } + + if (trimByteIdx == numBytes - 1) { + // Nothing trimmed. + return srcString; + } + if (trimByteIdx < 0) { + // Everything trimmed. + return UTF8String.EMPTY_UTF8; + } + return srcString.copyUTF8String(0, trimByteIdx); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index b77671cee90b..bea3dc08b448 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -16,23 +16,15 @@ */ package org.apache.spark.sql.catalyst.util; -import com.ibm.icu.lang.UCharacter; -import com.ibm.icu.text.BreakIterator; import com.ibm.icu.text.StringSearch; -import com.ibm.icu.util.ULocale; -import org.apache.spark.unsafe.UTF8StringBuilder; import org.apache.spark.unsafe.types.UTF8String; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.regex.Pattern; -import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; -import static org.apache.spark.unsafe.Platform.copyMemory; - /** * Static entry point for collation-aware expressions (StringExpressions, 
RegexpExpressions, and * other expressions that require custom collation support), as well as private utility methods for @@ -441,7 +433,7 @@ public final class CollationSupport { return string.toLowerCase().indexOf(substring.toLowerCase(), start); } public static int execICU(final UTF8String string, final UTF8String substring, final int start, - final int collationId) { + final int collationId) { return CollationAwareUTF8String.indexOf(string, substring, start, collationId); } } @@ -535,6 +527,201 @@ public final class CollationSupport { } } + public static class StringTrim { + public static UTF8String exec( + final UTF8String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString); + } else { + return execLowercase(srcString); + } + } + public static UTF8String exec( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString, trimString); + } else { + return execLowercase(srcString, trimString); + } + } + public static String genCode( + final String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrim.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s)", srcString); + } { + return String.format(expr + "Lowercase(%s)", srcString); + } + } + public static String genCode( + final String srcString, + final String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrim.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", srcString, 
trimString); + } else { + return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); + } + } + public static UTF8String execBinary( + final UTF8String srcString) { + return srcString.trim(); + } + public static UTF8String execBinary( + final UTF8String srcString, + final UTF8String trimString) { + return srcString.trim(trimString); + } + public static UTF8String execLowercase( + final UTF8String srcString) { + return srcString.trim(); + } + public static UTF8String execLowercase( + final UTF8String srcString, + final UTF8String trimString) { + return CollationAwareUTF8String.lowercaseTrim(srcString, trimString); + } + } + + public static class StringTrimLeft { + public static UTF8String exec( + final UTF8String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString); + } else { + return execLowercase(srcString); + } + } + public static UTF8String exec( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString, trimString); + } else { + return execLowercase(srcString, trimString); + } + } + public static String genCode( + final String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrimLeft.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s)", srcString); + } else { + return String.format(expr + "Lowercase(%s)", srcString); + } + } + public static String genCode( + final String srcString, + final String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = 
"CollationSupport.StringTrimLeft.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", srcString, trimString); + } else { + return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); + } + } + public static UTF8String execBinary( + final UTF8String srcString) { + return srcString.trimLeft(); + } + public static UTF8String execBinary( + final UTF8String srcString, + final UTF8String trimString) { + return srcString.trimLeft(trimString); + } + public static UTF8String execLowercase( + final UTF8String srcString) { + return srcString.trimLeft(); + } + public static UTF8String execLowercase( + final UTF8String srcString, + final UTF8String trimString) { + return CollationAwareUTF8String.lowercaseTrimLeft(srcString, trimString); + } + } + + public static class StringTrimRight { + public static UTF8String exec( + final UTF8String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString); + } else { + return execLowercase(srcString); + } + } + public static UTF8String exec( + final UTF8String srcString, + final UTF8String trimString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(srcString, trimString); + } else { + return execLowercase(srcString, trimString); + } + } + public static String genCode( + final String srcString, + final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrimRight.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s)", srcString); + } else { + return String.format(expr + "Lowercase(%s)", srcString); + } + } + public static String genCode( + final String srcString, + final String trimString, + 
final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringTrimRight.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", srcString, trimString); + } else { + return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); + } + } + public static UTF8String execBinary( + final UTF8String srcString) { + return srcString.trimRight(); + } + public static UTF8String execBinary( + final UTF8String srcString, + final UTF8String trimString) { + return srcString.trimRight(trimString); + } + public static UTF8String execLowercase( + final UTF8String srcString) { + return srcString.trimRight(); + } + public static UTF8String execLowercase( + final UTF8String srcString, + final UTF8String trimString) { + return CollationAwareUTF8String.lowercaseTrimRight(srcString, trimString); + } + } + // TODO: Add more collation-aware string expressions. /** @@ -566,333 +753,4 @@ public final class CollationSupport { // TODO: Add other collation-aware expressions. - /** - * Utility class for collation-aware UTF8String operations. - */ - - private static class CollationAwareUTF8String { - - private static UTF8String replace(final UTF8String src, final UTF8String search, - final UTF8String replace, final int collationId) { - // This collation aware implementation is based on existing implementation on UTF8String - if (src.numBytes() == 0 || search.numBytes() == 0) { - return src; - } - - StringSearch stringSearch = CollationFactory.getStringSearch(src, search, collationId); - - // Find the first occurrence of the search string. - int end = stringSearch.next(); - if (end == StringSearch.DONE) { - // Search string was not found, so string is unchanged. 
- return src; - } - - // Initialize byte positions - int c = 0; - int byteStart = 0; // position in byte - int byteEnd = 0; // position in byte - while (byteEnd < src.numBytes() && c < end) { - byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); - c += 1; - } - - // At least one match was found. Estimate space needed for result. - // The 16x multiplier here is chosen to match commons-lang3's implementation. - int increase = Math.max(0, Math.abs(replace.numBytes() - search.numBytes())) * 16; - final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); - while (end != StringSearch.DONE) { - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); - buf.append(replace); - - // Move byteStart to the beginning of the current match - byteStart = byteEnd; - int cs = c; - // Move cs to the end of the current match - // This is necessary because the search string may contain 'multi-character' characters - while (byteStart < src.numBytes() && cs < c + stringSearch.getMatchLength()) { - byteStart += UTF8String.numBytesForFirstByte(src.getByte(byteStart)); - cs += 1; - } - // Go to next match - end = stringSearch.next(); - // Update byte positions - while (byteEnd < src.numBytes() && c < end) { - byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); - c += 1; - } - } - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, - src.numBytes() - byteStart); - return buf.build(); - } - - private static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search, - final UTF8String replace) { - if (src.numBytes() == 0 || search.numBytes() == 0) { - return src; - } - UTF8String lowercaseString = src.toLowerCase(); - UTF8String lowercaseSearch = search.toLowerCase(); - - int start = 0; - int end = lowercaseString.indexOf(lowercaseSearch, 0); - if (end == -1) { - // Search string was not found, so string is unchanged. 
- return src; - } - - // Initialize byte positions - int c = 0; - int byteStart = 0; // position in byte - int byteEnd = 0; // position in byte - while (byteEnd < src.numBytes() && c < end) { - byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); - c += 1; - } - - // At least one match was found. Estimate space needed for result. - // The 16x multiplier here is chosen to match commons-lang3's implementation. - int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16; - final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase); - while (end != -1) { - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart); - buf.append(replace); - // Update character positions - start = end + lowercaseSearch.numChars(); - end = lowercaseString.indexOf(lowercaseSearch, start); - // Update byte positions - byteStart = byteEnd + search.numBytes(); - while (byteEnd < src.numBytes() && c < end) { - byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd)); - c += 1; - } - } - buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, - src.numBytes() - byteStart); - return buf.build(); - } - - private static String toUpperCase(final String target, final int collationId) { - ULocale locale = CollationFactory.fetchCollation(collationId) - .collator.getLocale(ULocale.ACTUAL_LOCALE); - return UCharacter.toUpperCase(locale, target); - } - - private static String toLowerCase(final String target, final int collationId) { - ULocale locale = CollationFactory.fetchCollation(collationId) - .collator.getLocale(ULocale.ACTUAL_LOCALE); - return UCharacter.toLowerCase(locale, target); - } - - private static String toTitleCase(final String target, final int collationId) { - ULocale locale = CollationFactory.fetchCollation(collationId) - .collator.getLocale(ULocale.ACTUAL_LOCALE); - return UCharacter.toTitleCase(locale, target, BreakIterator.getWordInstance(locale)); - } - - private static int 
findInSet(final UTF8String match, final UTF8String set, int collationId) { - if (match.contains(UTF8String.fromString(","))) { - return 0; - } - - String setString = set.toString(); - StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(), - collationId); - - int wordStart = 0; - while ((wordStart = stringSearch.next()) != StringSearch.DONE) { - boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ','; - boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length() - || setString.charAt(wordStart + stringSearch.getMatchLength()) == ','; - - if (isValidStart && isValidEnd) { - int pos = 0; - for (int i = 0; i < setString.length() && i < wordStart; i++) { - if (setString.charAt(i) == ',') { - pos++; - } - } - - return pos + 1; - } - } - - return 0; - } - - private static int indexOf(final UTF8String target, final UTF8String pattern, - final int start, final int collationId) { - if (pattern.numBytes() == 0) { - return 0; - } - - StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); - stringSearch.setIndex(start); - - return stringSearch.next(); - } - - private static int find(UTF8String target, UTF8String pattern, int start, - int collationId) { - assert (pattern.numBytes() > 0); - - StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); - // Set search start position (start from character at start position) - stringSearch.setIndex(target.bytePosToChar(start)); - - // Return either the byte position or -1 if not found - return target.charPosToByte(stringSearch.next()); - } - - private static UTF8String subStringIndex(final UTF8String string, final UTF8String delimiter, - int count, final int collationId) { - if (delimiter.numBytes() == 0 || count == 0 || string.numBytes() == 0) { - return UTF8String.EMPTY_UTF8; - } - if (count > 0) { - int idx = -1; - while (count > 0) { - idx = find(string, delimiter, idx + 1, 
collationId); - if (idx >= 0) { - count --; - } else { - // can not find enough delim - return string; - } - } - if (idx == 0) { - return UTF8String.EMPTY_UTF8; - } - byte[] bytes = new byte[idx]; - copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx); - return UTF8String.fromBytes(bytes); - - } else { - count = -count; - - StringSearch stringSearch = CollationFactory - .getStringSearch(string, delimiter, collationId); - - int start = string.numChars() - 1; - int lastMatchLength = 0; - int prevStart = -1; - while (count > 0) { - stringSearch.reset(); - prevStart = -1; - int matchStart = stringSearch.next(); - lastMatchLength = stringSearch.getMatchLength(); - while (matchStart <= start) { - if (matchStart != StringSearch.DONE) { - // Found a match, update the start position - prevStart = matchStart; - matchStart = stringSearch.next(); - } else { - break; - } - } - - if (prevStart == -1) { - // can not find enough delim - return string; - } else { - start = prevStart - 1; - count--; - } - } - - int resultStart = prevStart + lastMatchLength; - if (resultStart == string.numChars()) { - return UTF8String.EMPTY_UTF8; - } - - return string.substring(resultStart, string.numChars()); - } - } - - private static UTF8String lowercaseSubStringIndex(final UTF8String string, - final UTF8String delimiter, int count) { - if (delimiter.numBytes() == 0 || count == 0) { - return UTF8String.EMPTY_UTF8; - } - - UTF8String lowercaseString = string.toLowerCase(); - UTF8String lowercaseDelimiter = delimiter.toLowerCase(); - - if (count > 0) { - int idx = -1; - while (count > 0) { - idx = lowercaseString.find(lowercaseDelimiter, idx + 1); - if (idx >= 0) { - count --; - } else { - // can not find enough delim - return string; - } - } - if (idx == 0) { - return UTF8String.EMPTY_UTF8; - } - byte[] bytes = new byte[idx]; - copyMemory(string.getBaseObject(), string.getBaseOffset(), bytes, BYTE_ARRAY_OFFSET, idx); - return UTF8String.fromBytes(bytes); - - } 
else { - int idx = string.numBytes() - delimiter.numBytes() + 1; - count = -count; - while (count > 0) { - idx = lowercaseString.rfind(lowercaseDelimiter, idx - 1); - if (idx >= 0) { - count --; - } else { - // can not find enough delim - return string; - } - } - if (idx + delimiter.numBytes() == string.numBytes()) { - return UTF8String.EMPTY_UTF8; - } - int size = string.numBytes() - delimiter.numBytes() - idx; - byte[] bytes = new byte[size]; - copyMemory(string.getBaseObject(), string.getBaseOffset() + idx + delimiter.numBytes(), - bytes, BYTE_ARRAY_OFFSET, size); - return UTF8String.fromBytes(bytes); - } - } - - private static Map<String, String> getCollationAwareDict(UTF8String string, - Map<String, String> dict, int collationId) { - String srcStr = string.toString(); - - Map<String, String> collationAwareDict = new HashMap<>(); - for (String key : dict.keySet()) { - StringSearch stringSearch = - CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId); - - int pos = 0; - while ((pos = stringSearch.next()) != StringSearch.DONE) { - int codePoint = srcStr.codePointAt(pos); - int charCount = Character.charCount(codePoint); - String newKey = srcStr.substring(pos, pos + charCount); - - boolean exists = false; - for (String existingKey : collationAwareDict.keySet()) { - if (stringSearch.getCollator().compare(existingKey, newKey) == 0) { - collationAwareDict.put(newKey, collationAwareDict.get(existingKey)); - exists = true; - break; - } - } - - if (!exists) { - collationAwareDict.put(newKey, dict.get(key)); - } - } - } - - return collationAwareDict; - } - - } - } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 2a5d14580353..20b26b6ebc5a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -646,7 +646,7 @@ 
public final class UTF8String implements Comparable<UTF8String>, Externalizable, * @param end the end position of the current UTF8String in bytes. * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. */ - private UTF8String copyUTF8String(int start, int end) { + public UTF8String copyUTF8String(int start, int end) { int len = end - start + 1; byte[] newBytes = new byte[len]; copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 2f05b9ad88c9..7fc3c4e349c3 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -800,6 +800,199 @@ public class CollationSupportSuite { assertSubstringIndex("ai̇bİoi̇o12i̇oİo", "i̇o", -4, "UNICODE_CI", "i̇o12i̇oİo"); } + private void assertStringTrim( + String collation, + String sourceString, + String trimString, + String expectedResultString) throws SparkException { + int collationId = CollationFactory.collationNameToId(collation); + String result; + + if (trimString == null) { + result = CollationSupport.StringTrim.exec( + UTF8String.fromString(sourceString), collationId).toString(); + } else { + result = CollationSupport.StringTrim.exec( + UTF8String + .fromString(sourceString), UTF8String.fromString(trimString), collationId) + .toString(); + } + + assertEquals(expectedResultString, result); + } + + private void assertStringTrimLeft( + String collation, + String sourceString, + String trimString, + String expectedResultString) throws SparkException { + int collationId = CollationFactory.collationNameToId(collation); + String result; + + if (trimString == null) { + result = CollationSupport.StringTrimLeft.exec( + UTF8String.fromString(sourceString), 
collationId).toString(); + } else { + result = CollationSupport.StringTrimLeft.exec( + UTF8String + .fromString(sourceString), UTF8String.fromString(trimString), collationId) + .toString(); + } + + assertEquals(expectedResultString, result); + } + + private void assertStringTrimRight( + String collation, + String sourceString, + String trimString, + String expectedResultString) throws SparkException { + int collationId = CollationFactory.collationNameToId(collation); + String result; + + if (trimString == null) { + result = CollationSupport.StringTrimRight.exec( + UTF8String.fromString(sourceString), collationId).toString(); + } else { + result = CollationSupport.StringTrimRight.exec( + UTF8String + .fromString(sourceString), UTF8String.fromString(trimString), collationId) + .toString(); + } + + assertEquals(expectedResultString, result); + } + + @Test + public void testStringTrim() throws SparkException { + assertStringTrim("UTF8_BINARY", "asd", null, "asd"); + assertStringTrim("UTF8_BINARY", " asd ", null, "asd"); + assertStringTrim("UTF8_BINARY", " a世a ", null, "a世a"); + assertStringTrim("UTF8_BINARY", "asd", "x", "asd"); + assertStringTrim("UTF8_BINARY", "xxasdxx", "x", "asd"); + assertStringTrim("UTF8_BINARY", "xa世ax", "x", "a世a"); + + assertStringTrimLeft("UTF8_BINARY", "asd", null, "asd"); + assertStringTrimLeft("UTF8_BINARY", " asd ", null, "asd "); + assertStringTrimLeft("UTF8_BINARY", " a世a ", null, "a世a "); + assertStringTrimLeft("UTF8_BINARY", "asd", "x", "asd"); + assertStringTrimLeft("UTF8_BINARY", "xxasdxx", "x", "asdxx"); + assertStringTrimLeft("UTF8_BINARY", "xa世ax", "x", "a世ax"); + + assertStringTrimRight("UTF8_BINARY", "asd", null, "asd"); + assertStringTrimRight("UTF8_BINARY", " asd ", null, " asd"); + assertStringTrimRight("UTF8_BINARY", " a世a ", null, " a世a"); + assertStringTrimRight("UTF8_BINARY", "asd", "x", "asd"); + assertStringTrimRight("UTF8_BINARY", "xxasdxx", "x", "xxasd"); + assertStringTrimRight("UTF8_BINARY", "xa世ax", "x", "xa世a"); 
+ + assertStringTrim("UTF8_BINARY_LCASE", "asd", null, "asd"); + assertStringTrim("UTF8_BINARY_LCASE", " asd ", null, "asd"); + assertStringTrim("UTF8_BINARY_LCASE", " a世a ", null, "a世a"); + assertStringTrim("UTF8_BINARY_LCASE", "asd", "x", "asd"); + assertStringTrim("UTF8_BINARY_LCASE", "xxasdxx", "x", "asd"); + assertStringTrim("UTF8_BINARY_LCASE", "xa世ax", "x", "a世a"); + + assertStringTrimLeft("UTF8_BINARY_LCASE", "asd", null, "asd"); + assertStringTrimLeft("UTF8_BINARY_LCASE", " asd ", null, "asd "); + assertStringTrimLeft("UTF8_BINARY_LCASE", " a世a ", null, "a世a "); + assertStringTrimLeft("UTF8_BINARY_LCASE", "asd", "x", "asd"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "xxasdxx", "x", "asdxx"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "xa世ax", "x", "a世ax"); + + assertStringTrimRight("UTF8_BINARY_LCASE", "asd", null, "asd"); + assertStringTrimRight("UTF8_BINARY_LCASE", " asd ", null, " asd"); + assertStringTrimRight("UTF8_BINARY_LCASE", " a世a ", null, " a世a"); + assertStringTrimRight("UTF8_BINARY_LCASE", "asd", "x", "asd"); + assertStringTrimRight("UTF8_BINARY_LCASE", "xxasdxx", "x", "xxasd"); + assertStringTrimRight("UTF8_BINARY_LCASE", "xa世ax", "x", "xa世a"); + + assertStringTrim("UTF8_BINARY_LCASE", "asd", null, "asd"); + assertStringTrim("UTF8_BINARY_LCASE", " asd ", null, "asd"); + assertStringTrim("UTF8_BINARY_LCASE", " a世a ", null, "a世a"); + assertStringTrim("UTF8_BINARY_LCASE", "asd", "x", "asd"); + assertStringTrim("UTF8_BINARY_LCASE", "xxasdxx", "x", "asd"); + assertStringTrim("UTF8_BINARY_LCASE", "xa世ax", "x", "a世a"); + + assertStringTrimLeft("UNICODE", "asd", null, "asd"); + assertStringTrimLeft("UNICODE", " asd ", null, "asd "); + assertStringTrimLeft("UNICODE", " a世a ", null, "a世a "); + assertStringTrimLeft("UNICODE", "asd", "x", "asd"); + assertStringTrimLeft("UNICODE", "xxasdxx", "x", "asdxx"); + assertStringTrimLeft("UNICODE", "xa世ax", "x", "a世ax"); + + assertStringTrimRight("UNICODE", "asd", null, "asd"); + 
assertStringTrimRight("UNICODE", " asd ", null, " asd"); + assertStringTrimRight("UNICODE", " a世a ", null, " a世a"); + assertStringTrimRight("UNICODE", "asd", "x", "asd"); + assertStringTrimRight("UNICODE", "xxasdxx", "x", "xxasd"); + assertStringTrimRight("UNICODE", "xa世ax", "x", "xa世a"); + + // Test cases where trimString has more than one character + assertStringTrim("UTF8_BINARY", "ddsXXXaa", "asd", "XXX"); + assertStringTrimLeft("UTF8_BINARY", "ddsXXXaa", "asd", "XXXaa"); + assertStringTrimRight("UTF8_BINARY", "ddsXXXaa", "asd", "ddsXXX"); + + assertStringTrim("UTF8_BINARY_LCASE", "ddsXXXaa", "asd", "XXX"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "ddsXXXaa", "asd", "XXXaa"); + assertStringTrimRight("UTF8_BINARY_LCASE", "ddsXXXaa", "asd", "ddsXXX"); + + assertStringTrim("UNICODE", "ddsXXXaa", "asd", "XXX"); + assertStringTrimLeft("UNICODE", "ddsXXXaa", "asd", "XXXaa"); + assertStringTrimRight("UNICODE", "ddsXXXaa", "asd", "ddsXXX"); + + // Test cases specific to collation type + // uppercase trim, lowercase src + assertStringTrim("UTF8_BINARY", "asd", "A", "asd"); + assertStringTrim("UTF8_BINARY_LCASE", "asd", "A", "sd"); + assertStringTrim("UNICODE", "asd", "A", "asd"); + assertStringTrim("UNICODE_CI", "asd", "A", "sd"); + + // lowercase trim, uppercase src + assertStringTrim("UTF8_BINARY", "ASD", "a", "ASD"); + assertStringTrim("UTF8_BINARY_LCASE", "ASD", "a", "SD"); + assertStringTrim("UNICODE", "ASD", "a", "ASD"); + assertStringTrim("UNICODE_CI", "ASD", "a", "SD"); + + // uppercase and lowercase chars of different byte-length (utf8) + assertStringTrim("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimLeft("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimRight("UTF8_BINARY", "ẞaaaẞ", "ß", "ẞaaaẞ"); + + assertStringTrim("UTF8_BINARY_LCASE", "ẞaaaẞ", "ß", "aaa"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "ẞaaaẞ", "ß", "aaaẞ"); + assertStringTrimRight("UTF8_BINARY_LCASE", "ẞaaaẞ", "ß", "ẞaaa"); + + assertStringTrim("UNICODE", "ẞaaaẞ", "ß", 
"ẞaaaẞ"); + assertStringTrimLeft("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ"); + assertStringTrimRight("UNICODE", "ẞaaaẞ", "ß", "ẞaaaẞ"); + + assertStringTrim("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimLeft("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimRight("UTF8_BINARY", "ßaaaß", "ẞ", "ßaaaß"); + + assertStringTrim("UTF8_BINARY_LCASE", "ßaaaß", "ẞ", "aaa"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "ßaaaß", "ẞ", "aaaß"); + assertStringTrimRight("UTF8_BINARY_LCASE", "ßaaaß", "ẞ", "ßaaa"); + + assertStringTrim("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimLeft("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); + assertStringTrimRight("UNICODE", "ßaaaß", "ẞ", "ßaaaß"); + + // different byte-length (utf8) chars trimmed + assertStringTrim("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "aaa"); + assertStringTrimLeft("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "aaaẞ"); + assertStringTrimRight("UTF8_BINARY", "Ëaaaẞ", "Ëẞ", "Ëaaa"); + + assertStringTrim("UTF8_BINARY_LCASE", "Ëaaaẞ", "Ëẞ", "aaa"); + assertStringTrimLeft("UTF8_BINARY_LCASE", "Ëaaaẞ", "Ëẞ", "aaaẞ"); + assertStringTrimRight("UTF8_BINARY_LCASE", "Ëaaaẞ", "Ëẞ", "Ëaaa"); + + assertStringTrim("UNICODE", "Ëaaaẞ", "Ëẞ", "aaa"); + assertStringTrimLeft("UNICODE", "Ëaaaẞ", "Ëẞ", "aaaẞ"); + assertStringTrimRight("UNICODE", "Ëaaaẞ", "Ëẞ", "Ëaaa"); + } + // TODO: Test more collation-aware string expressions. 
/** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 44349384187e..a50dad7c8cdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -74,7 +74,7 @@ object CollationTypeCasts extends TypeCoercionRule { case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace | - _: StringTranslate) => + _: StringTranslate | _: StringTrim | _: StringTrimLeft | _: StringTrimRight) => val newChildren = collateToSingleType(otherExpr.children) otherExpr.withNewChildren(newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 0bdd7930b0bf..09ec501311ad 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import 
org.apache.spark.unsafe.array.ByteArrayMethods @@ -1020,8 +1020,10 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { protected def direction: String override def children: Seq[Expression] = srcStr +: trimStr.toSeq - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = srcStr.dataType + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringTypeBinaryLcase) + + final lazy val collationId: Int = srcStr.dataType.asInstanceOf[StringType].collationId override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -1040,13 +1042,19 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { } } - protected val trimMethod: String - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evals = children.map(_.genCode(ctx)) - val srcString = evals(0) + val srcString = evals.head if (evals.length == 1) { + val stringTrimCode: String = this match { + case _: StringTrim => + CollationSupport.StringTrim.genCode(srcString.value, collationId) + case _: StringTrimLeft => + CollationSupport.StringTrimLeft.genCode(srcString.value, collationId) + case _: StringTrimRight => + CollationSupport.StringTrimRight.genCode(srcString.value, collationId) + } ev.copy(code = code""" |${srcString.code} |boolean ${ev.isNull} = false; @@ -1054,10 +1062,18 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { |if (${srcString.isNull}) { | ${ev.isNull} = true; |} else { - | ${ev.value} = ${srcString.value}.$trimMethod(); + | ${ev.value} = $stringTrimCode; |}""".stripMargin) } else { val trimString = evals(1) + val stringTrimCode: String = this match { + case _: StringTrim => + CollationSupport.StringTrim.genCode(srcString.value, trimString.value, collationId) + case _: StringTrimLeft => + 
CollationSupport.StringTrimLeft.genCode(srcString.value, trimString.value, collationId) + case _: StringTrimRight => + CollationSupport.StringTrimRight.genCode(srcString.value, trimString.value, collationId) + } ev.copy(code = code""" |${srcString.code} |boolean ${ev.isNull} = false; @@ -1069,7 +1085,7 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { | if (${trimString.isNull}) { | ${ev.isNull} = true; | } else { - | ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value}); + | ${ev.value} = $stringTrimCode; | } |}""".stripMargin) } @@ -1162,12 +1178,11 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) override protected def direction: String = "BOTH" - override def doEval(srcString: UTF8String): UTF8String = srcString.trim() + override def doEval(srcString: UTF8String): UTF8String = + CollationSupport.StringTrim.exec(srcString, collationId) override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trim(trimString) - - override val trimMethod: String = "trim" + CollationSupport.StringTrim.exec(srcString, trimString, collationId) override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy( @@ -1270,12 +1285,11 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None override protected def direction: String = "LEADING" - override def doEval(srcString: UTF8String): UTF8String = srcString.trimLeft() + override def doEval(srcString: UTF8String): UTF8String = + CollationSupport.StringTrimLeft.exec(srcString, collationId) override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trimLeft(trimString) - - override val trimMethod: String = "trimLeft" + CollationSupport.StringTrimLeft.exec(srcString, trimString, collationId) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimLeft = @@ -1331,12 +1345,11 @@ case class 
StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non override protected def direction: String = "TRAILING" - override def doEval(srcString: UTF8String): UTF8String = srcString.trimRight() + override def doEval(srcString: UTF8String): UTF8String = + CollationSupport.StringTrimRight.exec(srcString, collationId) override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = - srcString.trimRight(trimString) - - override val trimMethod: String = "trimRight" + CollationSupport.StringTrimRight.exec(srcString, trimString, collationId) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimRight = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index b9a4fecd0465..9cc123b708af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal, StringTrim, StringTrimLeft, StringTrimRight} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -26,7 +27,8 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, // scalastyle:off nonascii class CollationStringExpressionsSuite extends QueryTest - with SharedSparkSession { + with SharedSparkSession + with ExpressionEvalHelper { test("Support ConcatWs string expression with collation") { // Supported collations @@ -800,6 +802,163 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("StringTrim* functions - unit tests for both paths (codegen and 
eval)") { + // Without trimString param. + checkEvaluation(StringTrim(Literal.create( " asd ", StringType("UTF8_BINARY"))), "asd") + checkEvaluation( + StringTrimLeft(Literal.create(" asd ", StringType("UTF8_BINARY_LCASE"))), "asd ") + checkEvaluation(StringTrimRight(Literal.create(" asd ", StringType("UNICODE"))), " asd") + + // With trimString param. + checkEvaluation( + StringTrim( + Literal.create(" asd ", StringType("UTF8_BINARY")), + Literal.create(" ", StringType("UTF8_BINARY"))), + "asd") + checkEvaluation( + StringTrimLeft( + Literal.create(" asd ", StringType("UTF8_BINARY_LCASE")), + Literal.create(" ", StringType("UTF8_BINARY_LCASE"))), + "asd ") + checkEvaluation( + StringTrimRight( + Literal.create(" asd ", StringType("UNICODE")), + Literal.create(" ", StringType("UNICODE"))), + " asd") + + checkEvaluation( + StringTrim( + Literal.create("xxasdxx", StringType("UTF8_BINARY")), + Literal.create("x", StringType("UTF8_BINARY"))), + "asd") + checkEvaluation( + StringTrimLeft( + Literal.create("xxasdxx", StringType("UTF8_BINARY_LCASE")), + Literal.create("x", StringType("UTF8_BINARY_LCASE"))), + "asdxx") + checkEvaluation( + StringTrimRight( + Literal.create("xxasdxx", StringType("UNICODE")), + Literal.create("x", StringType("UNICODE"))), + "xxasd") + } + + test("StringTrim* functions - E2E tests") { + case class StringTrimTestCase( + collation: String, + trimFunc: String, + sourceString: String, + hasTrimString: Boolean, + trimString: String, + expectedResultString: String) + + val testCases = Seq( + StringTrimTestCase("UTF8_BINARY", "TRIM", " asd ", false, null, "asd"), + StringTrimTestCase("UTF8_BINARY", "BTRIM", " asd ", true, null, null), + StringTrimTestCase("UTF8_BINARY", "LTRIM", "xxasdxx", true, "x", "asdxx"), + StringTrimTestCase("UTF8_BINARY", "RTRIM", "xxasdxx", true, "x", "xxasd"), + + StringTrimTestCase("UTF8_BINARY_LCASE", "TRIM", " asd ", true, null, null), + StringTrimTestCase("UTF8_BINARY_LCASE", "BTRIM", "xxasdxx", true, "x", "asd"), + 
StringTrimTestCase("UTF8_BINARY_LCASE", "LTRIM", "xxasdxx", true, "x", "asdxx"), + StringTrimTestCase("UTF8_BINARY_LCASE", "RTRIM", " asd ", false, null, " asd"), + + StringTrimTestCase("UNICODE", "TRIM", "xxasdxx", true, "x", "asd"), + StringTrimTestCase("UNICODE", "BTRIM", "xxasdxx", true, "x", "asd"), + StringTrimTestCase("UNICODE", "LTRIM", " asd ", false, null, "asd "), + StringTrimTestCase("UNICODE", "RTRIM", " asd ", true, null, null) + + // Other more complex cases can be found in unit tests in CollationSupportSuite.java. + ) + + testCases.foreach(testCase => { + var df: DataFrame = null + + if (testCase.trimFunc.equalsIgnoreCase("BTRIM")) { + // BTRIM has arguments in (srcStr, trimStr) order + df = sql(s"SELECT ${testCase.trimFunc}(" + + s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" + + (if (!testCase.hasTrimString) "" + else if (testCase.trimString == null) ", null" + else s", '${testCase.trimString}'") + + ")") + } + else { + // While other functions have arguments in (trimStr, srcStr) order + df = sql(s"SELECT ${testCase.trimFunc}(" + + (if (!testCase.hasTrimString) "" + else if (testCase.trimString == null) "null, " + else s"'${testCase.trimString}', ") + + s"COLLATE('${testCase.sourceString}', '${testCase.collation}')" + + ")") + } + + checkAnswer(df = df, expectedAnswer = Row(testCase.expectedResultString)) + }) + } + + test("StringTrim* functions - implicit collations") { + checkAnswer( + df = sql("SELECT TRIM(COLLATE('x', 'UTF8_BINARY'), COLLATE('xax', 'UTF8_BINARY'))"), + expectedAnswer = Row("a")) + checkAnswer( + df = sql("SELECT BTRIM(COLLATE('xax', 'UTF8_BINARY_LCASE'), " + + "COLLATE('x', 'UTF8_BINARY_LCASE'))"), + expectedAnswer = Row("a")) + checkAnswer( + df = sql("SELECT LTRIM(COLLATE('x', 'UNICODE'), COLLATE('xax', 'UNICODE'))"), + expectedAnswer = Row("ax")) + + checkAnswer( + df = sql("SELECT RTRIM('x', COLLATE('xax', 'UTF8_BINARY'))"), + expectedAnswer = Row("xa")) + checkAnswer( + df = sql("SELECT TRIM('x', 
COLLATE('xax', 'UTF8_BINARY_LCASE'))"), + expectedAnswer = Row("a")) + checkAnswer( + df = sql("SELECT BTRIM('xax', COLLATE('x', 'UNICODE'))"), + expectedAnswer = Row("a")) + + checkAnswer( + df = sql("SELECT LTRIM(COLLATE('x', 'UTF8_BINARY'), 'xax')"), + expectedAnswer = Row("ax")) + checkAnswer( + df = sql("SELECT RTRIM(COLLATE('x', 'UTF8_BINARY_LCASE'), 'xax')"), + expectedAnswer = Row("xa")) + checkAnswer( + df = sql("SELECT TRIM(COLLATE('x', 'UNICODE'), 'xax')"), + expectedAnswer = Row("a")) + } + + test("StringTrim* functions - collation type mismatch") { + List("TRIM", "LTRIM", "RTRIM").foreach(func => { + val collationMismatch = intercept[AnalysisException] { + sql("SELECT " + func + "(COLLATE('x', 'UTF8_BINARY_LCASE'), " + + "COLLATE('xxaaaxx', 'UNICODE'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + }) + + val collationMismatch = intercept[AnalysisException] { + sql("SELECT BTRIM(COLLATE('xxaaaxx', 'UNICODE'), COLLATE('x', 'UTF8_BINARY_LCASE'))") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } + + test("StringTrim* functions - unsupported collation types") { + List("TRIM", "LTRIM", "RTRIM").foreach(func => { + val collationMismatch = intercept[AnalysisException] { + sql("SELECT " + func + "(COLLATE('x', 'UNICODE_CI'), COLLATE('xxaaaxx', 'UNICODE_CI'))") + } + assert(collationMismatch.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + }) + + val collationMismatch = intercept[AnalysisException] { + sql("SELECT BTRIM(COLLATE('xxaaaxx', 'UNICODE_CI'), COLLATE('x', 'UNICODE_CI'))") + } + assert(collationMismatch.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE") + } + // TODO: Add more tests for other string expressions } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org