This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7fe1b93884aa [SPARK-46841][SQL] Add collation support for ICU locales and collation specifiers 7fe1b93884aa is described below commit 7fe1b93884aa8e9ba20f19351b8537c687b8f59c Author: Nikola Mandic <nikola.man...@databricks.com> AuthorDate: Tue May 28 09:56:16 2024 -0700 [SPARK-46841][SQL] Add collation support for ICU locales and collation specifiers ### What changes were proposed in this pull request? Languages and localization for collations are supported by ICU library. Collation naming format is as follows: ``` <2-letter language code>[_<4-letter script>][_<3-letter country code>][_specifier_specifier...] ``` Locale specifier consists of the first part of collation name (language + script + country). Locale specifiers need to be stable across ICU versions; to keep existing ids and names invariant we introduce a golden file with the locale table, which should cause CI failure on any silent changes. Currently supported optional specifiers: - `CS`/`CI` - case sensitivity, default is case-sensitive; supported by configuring ICU collation levels - `AS`/`AI` - accent sensitivity, default is accent-sensitive; supported by configuring ICU collation levels User can use collation specifiers in any order except for locale which is mandatory and must go first. There is a one-to-one mapping between collation ids and collation names defined in `CollationFactory`. ### Why are the changes needed? To add languages and localization support for collations. ### Does this PR introduce _any_ user-facing change? Yes, it adds new predefined collations. ### How was this patch tested? Added checks to `CollationFactorySuite` and ICU locale map golden file. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46180 from nikolamand-db/SPARK-46841. 
Authored-by: Nikola Mandic <nikola.man...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/util/CollationFactory.java | 678 +++++++++++++++++---- .../spark/unsafe/types/CollationFactorySuite.scala | 323 +++++++++- .../src/main/resources/error/error-conditions.json | 4 +- .../apache/spark/sql/PlanGenerationTestSuite.scala | 4 +- .../src/main/protobuf/spark/connect/types.proto | 2 +- .../connect/common/DataTypeProtoConverter.scala | 9 +- .../query-tests/queries/csv_from_dataset.json | 2 +- .../query-tests/queries/csv_from_dataset.proto.bin | Bin 158 -> 169 bytes .../query-tests/queries/function_lit_array.json | 4 +- .../queries/function_lit_array.proto.bin | Bin 889 -> 911 bytes .../query-tests/queries/function_typedLit.json | 32 +- .../queries/function_typedLit.proto.bin | Bin 1199 -> 1381 bytes .../query-tests/queries/json_from_dataset.json | 2 +- .../queries/json_from_dataset.proto.bin | Bin 169 -> 180 bytes python/pyspark/sql/connect/proto/types_pb2.py | 78 +-- python/pyspark/sql/connect/proto/types_pb2.pyi | 11 +- python/pyspark/sql/connect/types.py | 5 +- python/pyspark/sql/types.py | 27 +- .../org/apache/spark/sql/internal/SQLConf.scala | 15 +- .../expressions/CollationExpressionSuite.scala | 33 +- .../resources/collations/ICU-collations-map.md | 143 +++++ .../sql-tests/analyzer-results/collations.sql.out | 77 +++ .../test/resources/sql-tests/inputs/collations.sql | 13 + .../resources/sql-tests/results/collations.sql.out | 88 +++ .../org/apache/spark/sql/CollationSuite.scala | 2 +- .../apache/spark/sql/ICUCollationsMapSuite.scala | 69 +++ .../apache/spark/sql/internal/SQLConfSuite.scala | 3 +- 27 files changed, 1388 insertions(+), 236 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 0133c3feb611..fce12510afaf 100644 --- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.util; import java.text.CharacterIterator; import java.text.StringCharacterIterator; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiFunction; import java.util.function.ToLongFunction; @@ -173,26 +174,546 @@ public final class CollationFactory { } /** - * Constructor with comparators that are inherited from the given collator. + * Collation ID is defined as 32-bit integer. We specify binary layouts for different classes of + * collations. Classes of collations are differentiated by most significant 3 bits (bit 31, 30 + * and 29), bit 31 being most significant and bit 0 being least significant. + * --- + * General collation ID binary layout: + * bit 31: 1 for INDETERMINATE (requires all other bits to be 1 as well), 0 otherwise. + * bit 30: 0 for predefined, 1 for user-defined. + * Following bits are specified for predefined collations: + * bit 29: 0 for UTF8_BINARY, 1 for ICU collations. + * bit 28-24: Reserved. + * bit 23-22: Reserved for version. + * bit 21-18: Reserved for space trimming. + * bit 17-0: Depend on collation family. + * --- + * INDETERMINATE collation ID binary layout: + * bit 31-0: 1 + * INDETERMINATE collation ID is equal to -1. + * --- + * User-defined collation ID binary layout: + * bit 31: 0 + * bit 30: 1 + * bit 29-0: Undefined, reserved for future use. + * --- + * UTF8_BINARY collation ID binary layout: + * bit 31-24: Zeroes. + * bit 23-22: Zeroes, reserved for version. + * bit 21-18: Zeroes, reserved for space trimming. + * bit 17-3: Zeroes. + * bit 2: 0, reserved for accent sensitivity. + * bit 1: 0, reserved for uppercase and case-insensitive. + * bit 0: 0 = case-sensitive, 1 = lowercase. + * --- + * ICU collation ID binary layout: + * bit 31-30: Zeroes. 
+ * bit 29: 1 + * bit 28-24: Zeroes. + * bit 23-22: Zeroes, reserved for version. + * bit 21-18: Zeroes, reserved for space trimming. + * bit 17: 0 = case-sensitive, 1 = case-insensitive. + * bit 16: 0 = accent-sensitive, 1 = accent-insensitive. + * bit 15-14: Zeroes, reserved for punctuation sensitivity. + * bit 13-12: Zeroes, reserved for first letter preference. + * bit 11-0: Locale ID as specified in `ICULocaleToId` mapping. + * --- + * Some illustrative examples of collation name to ID mapping: + * - UTF8_BINARY -> 0 + * - UTF8_BINARY_LCASE -> 1 + * - UNICODE -> 0x20000000 + * - UNICODE_AI -> 0x20010000 + * - UNICODE_CI -> 0x20020000 + * - UNICODE_CI_AI -> 0x20030000 + * - af -> 0x20000001 + * - af_CI_AI -> 0x20030001 */ - public Collation( - String collationName, - String provider, - Collator collator, - String version, - boolean supportsBinaryEquality, - boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality) { - this( - collationName, - provider, - collator, - (s1, s2) -> collator.compare(s1.toString(), s2.toString()), - version, - s -> (long)collator.getCollationKey(s.toString()).hashCode(), - supportsBinaryEquality, - supportsBinaryOrdering, - supportsLowercaseEquality); + private abstract static class CollationSpec { + + /** + * Bit 30 in collation ID having value 0 for predefined and 1 for user-defined collation. + */ + private enum DefinitionOrigin { + PREDEFINED, USER_DEFINED + } + + /** + * Bit 29 in collation ID having value 0 for UTF8_BINARY family and 1 for ICU family of + * collations. + */ + protected enum ImplementationProvider { + UTF8_BINARY, ICU + } + + /** + * Offset in binary collation ID layout. + */ + private static final int DEFINITION_ORIGIN_OFFSET = 30; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. + */ + private static final int DEFINITION_ORIGIN_MASK = 0b1; + + /** + * Offset in binary collation ID layout. 
+ */ + protected static final int IMPLEMENTATION_PROVIDER_OFFSET = 29; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. + */ + protected static final int IMPLEMENTATION_PROVIDER_MASK = 0b1; + + private static final int INDETERMINATE_COLLATION_ID = -1; + + /** + * Thread-safe cache mapping collation IDs to corresponding `Collation` instances. + * We add entries to this cache lazily as new `Collation` instances are requested. + */ + private static final Map<Integer, Collation> collationMap = new ConcurrentHashMap<>(); + + /** + * Utility function to retrieve `ImplementationProvider` enum instance from collation ID. + */ + private static ImplementationProvider getImplementationProvider(int collationId) { + return ImplementationProvider.values()[SpecifierUtils.getSpecValue(collationId, + IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK)]; + } + + /** + * Utility function to retrieve `DefinitionOrigin` enum instance from collation ID. + */ + private static DefinitionOrigin getDefinitionOrigin(int collationId) { + return DefinitionOrigin.values()[SpecifierUtils.getSpecValue(collationId, + DEFINITION_ORIGIN_OFFSET, DEFINITION_ORIGIN_MASK)]; + } + + /** + * Main entry point for retrieving `Collation` instance from collation ID. + */ + private static Collation fetchCollation(int collationId) { + // User-defined collations and INDETERMINATE collations cannot produce a `Collation` + // instance. + assert (collationId >= 0 && getDefinitionOrigin(collationId) + == DefinitionOrigin.PREDEFINED); + if (collationId == UTF8_BINARY_COLLATION_ID) { + // Skip cache. + return CollationSpecUTF8Binary.UTF8_BINARY_COLLATION; + } else if (collationMap.containsKey(collationId)) { + // Already in cache. + return collationMap.get(collationId); + } else { + // Build `Collation` instance and put into cache. 
+ CollationSpec spec; + ImplementationProvider implementationProvider = getImplementationProvider(collationId); + if (implementationProvider == ImplementationProvider.UTF8_BINARY) { + spec = CollationSpecUTF8Binary.fromCollationId(collationId); + } else { + spec = CollationSpecICU.fromCollationId(collationId); + } + Collation collation = spec.buildCollation(); + collationMap.put(collationId, collation); + return collation; + } + } + + protected static SparkException collationInvalidNameException(String collationName) { + return new SparkException("COLLATION_INVALID_NAME", + SparkException.constructMessageParams(Map.of("collationName", collationName)), null); + } + + private static int collationNameToId(String collationName) throws SparkException { + // Collation names provided by user are treated as case-insensitive. + String collationNameUpper = collationName.toUpperCase(); + if (collationNameUpper.startsWith("UTF8_BINARY")) { + return CollationSpecUTF8Binary.collationNameToId(collationName, collationNameUpper); + } else { + return CollationSpecICU.collationNameToId(collationName, collationNameUpper); + } + } + + protected abstract Collation buildCollation(); + } + + private static class CollationSpecUTF8Binary extends CollationSpec { + + /** + * Bit 0 in collation ID having value 0 for plain UTF8_BINARY and 1 for UTF8_BINARY_LCASE + * collation. + */ + private enum CaseSensitivity { + UNSPECIFIED, LCASE + } + + /** + * Offset in binary collation ID layout. + */ + private static final int CASE_SENSITIVITY_OFFSET = 0; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. 
+ */ + private static final int CASE_SENSITIVITY_MASK = 0b1; + + private static final int UTF8_BINARY_COLLATION_ID = + new CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).collationId; + private static final int UTF8_BINARY_LCASE_COLLATION_ID = + new CollationSpecUTF8Binary(CaseSensitivity.LCASE).collationId; + protected static Collation UTF8_BINARY_COLLATION = + new CollationSpecUTF8Binary(CaseSensitivity.UNSPECIFIED).buildCollation(); + protected static Collation UTF8_BINARY_LCASE_COLLATION = + new CollationSpecUTF8Binary(CaseSensitivity.LCASE).buildCollation(); + + private final int collationId; + + private CollationSpecUTF8Binary(CaseSensitivity caseSensitivity) { + this.collationId = + SpecifierUtils.setSpecValue(0, CASE_SENSITIVITY_OFFSET, caseSensitivity); + } + + private static int collationNameToId(String originalName, String collationName) + throws SparkException { + if (UTF8_BINARY_COLLATION.collationName.equals(collationName)) { + return UTF8_BINARY_COLLATION_ID; + } else if (UTF8_BINARY_LCASE_COLLATION.collationName.equals(collationName)) { + return UTF8_BINARY_LCASE_COLLATION_ID; + } else { + // Throw exception with original (before case conversion) collation name. + throw collationInvalidNameException(originalName); + } + } + + private static CollationSpecUTF8Binary fromCollationId(int collationId) { + // Extract case sensitivity from collation ID. + int caseConversionOrdinal = SpecifierUtils.getSpecValue(collationId, + CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); + // Verify only case sensitivity bits were set settable in UTF8_BINARY family of collations. 
+ assert (SpecifierUtils.removeSpec(collationId, + CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK) == 0); + return new CollationSpecUTF8Binary(CaseSensitivity.values()[caseConversionOrdinal]); + } + + @Override + protected Collation buildCollation() { + if (collationId == UTF8_BINARY_COLLATION_ID) { + return new Collation( + "UTF8_BINARY", + PROVIDER_SPARK, + null, + UTF8String::binaryCompare, + "1.0", + s -> (long) s.hashCode(), + /* supportsBinaryEquality = */ true, + /* supportsBinaryOrdering = */ true, + /* supportsLowercaseEquality = */ false); + } else { + return new Collation( + "UTF8_BINARY_LCASE", + PROVIDER_SPARK, + null, + UTF8String::compareLowerCase, + "1.0", + s -> (long) s.toLowerCase().hashCode(), + /* supportsBinaryEquality = */ false, + /* supportsBinaryOrdering = */ false, + /* supportsLowercaseEquality = */ true); + } + } + } + + private static class CollationSpecICU extends CollationSpec { + + /** + * Bit 17 in collation ID having value 0 for case-sensitive and 1 for case-insensitive + * collation. + */ + private enum CaseSensitivity { + CS, CI + } + + /** + * Bit 16 in collation ID having value 0 for accent-sensitive and 1 for accent-insensitive + * collation. + */ + private enum AccentSensitivity { + AS, AI + } + + /** + * Offset in binary collation ID layout. + */ + private static final int CASE_SENSITIVITY_OFFSET = 17; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. + */ + private static final int CASE_SENSITIVITY_MASK = 0b1; + + /** + * Offset in binary collation ID layout. + */ + private static final int ACCENT_SENSITIVITY_OFFSET = 16; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. + */ + private static final int ACCENT_SENSITIVITY_MASK = 0b1; + + /** + * Array of locale names, each locale ID corresponds to the index in this array. + */ + private static final String[] ICULocaleNames; + + /** + * Mapping of locale names to corresponding `ULocale` instance. 
+ */ + private static final Map<String, ULocale> ICULocaleMap = new HashMap<>(); + + /** + * Used to parse user input collation names which are converted to uppercase. + */ + private static final Map<String, String> ICULocaleMapUppercase = new HashMap<>(); + + /** + * Reverse mapping of `ICULocaleNames`. + */ + private static final Map<String, Integer> ICULocaleToId = new HashMap<>(); + + /** + * ICU library Collator version passed to `Collation` instance. + */ + private static final String ICU_COLLATOR_VERSION = "153.120.0.0"; + + static { + ICULocaleMap.put("UNICODE", ULocale.ROOT); + // ICU-implemented `ULocale`s which have corresponding `Collator` installed. + ULocale[] locales = Collator.getAvailableULocales(); + // Build locale names in format: language["_" optional script]["_" optional country code]. + // Examples: en, en_USA, sr_Cyrl_SRB + for (ULocale locale : locales) { + // Skip variants. + if (locale.getVariant().isEmpty()) { + String language = locale.getLanguage(); + // Require non-empty language as first component of locale name. + assert (!language.isEmpty()); + StringBuilder builder = new StringBuilder(language); + // Script tag. + String script = locale.getScript(); + if (!script.isEmpty()) { + builder.append('_'); + builder.append(script); + } + // 3-letter country code. + String country = locale.getISO3Country(); + if (!country.isEmpty()) { + builder.append('_'); + builder.append(country); + } + String localeName = builder.toString(); + // Verify locale names are unique. + assert (!ICULocaleMap.containsKey(localeName)); + ICULocaleMap.put(localeName, locale); + } + } + // Construct uppercase-normalized locale name mapping. + for (String localeName : ICULocaleMap.keySet()) { + String localeUppercase = localeName.toUpperCase(); + // Locale names are unique case-insensitively. + assert (!ICULocaleMapUppercase.containsKey(localeUppercase)); + ICULocaleMapUppercase.put(localeUppercase, localeName); + } + // Construct locale name to ID mapping. 
Locale ID is defined as index in `ICULocaleNames`. + ICULocaleNames = ICULocaleMap.keySet().toArray(new String[0]); + Arrays.sort(ICULocaleNames); + // Maximum number of locale IDs as defined by binary layout. + assert (ICULocaleNames.length <= (1 << 12)); + for (int i = 0; i < ICULocaleNames.length; ++i) { + ICULocaleToId.put(ICULocaleNames[i], i); + } + } + + private static final int UNICODE_COLLATION_ID = + new CollationSpecICU("UNICODE", CaseSensitivity.CS, AccentSensitivity.AS).collationId; + private static final int UNICODE_CI_COLLATION_ID = + new CollationSpecICU("UNICODE", CaseSensitivity.CI, AccentSensitivity.AS).collationId; + + private final CaseSensitivity caseSensitivity; + private final AccentSensitivity accentSensitivity; + private final String locale; + private final int collationId; + + private CollationSpecICU(String locale, CaseSensitivity caseSensitivity, + AccentSensitivity accentSensitivity) { + this.locale = locale; + this.caseSensitivity = caseSensitivity; + this.accentSensitivity = accentSensitivity; + // Construct collation ID from locale, case-sensitivity and accent-sensitivity specifiers. + int collationId = ICULocaleToId.get(locale); + // Mandatory ICU implementation provider. + collationId = SpecifierUtils.setSpecValue(collationId, IMPLEMENTATION_PROVIDER_OFFSET, + ImplementationProvider.ICU); + collationId = SpecifierUtils.setSpecValue(collationId, CASE_SENSITIVITY_OFFSET, + caseSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, ACCENT_SENSITIVITY_OFFSET, + accentSensitivity); + this.collationId = collationId; + } + + private static int collationNameToId( + String originalName, String collationName) throws SparkException { + // Search for the longest locale match because specifiers are designed to be different from + // script tag and country code, meaning the only valid locale name match can be the longest + // one. 
+ int lastPos = -1; + for (int i = 1; i <= collationName.length(); i++) { + String localeName = collationName.substring(0, i); + if (ICULocaleMapUppercase.containsKey(localeName)) { + lastPos = i; + } + } + if (lastPos == -1) { + throw collationInvalidNameException(originalName); + } else { + String locale = collationName.substring(0, lastPos); + int collationId = ICULocaleToId.get(ICULocaleMapUppercase.get(locale)); + + // Try all combinations of AS/AI and CS/CI. + CaseSensitivity caseSensitivity; + AccentSensitivity accentSensitivity; + if (collationName.equals(locale) || + collationName.equals(locale + "_AS") || + collationName.equals(locale + "_CS") || + collationName.equals(locale + "_AS_CS") || + collationName.equals(locale + "_CS_AS") + ) { + caseSensitivity = CaseSensitivity.CS; + accentSensitivity = AccentSensitivity.AS; + } else if (collationName.equals(locale + "_CI") || + collationName.equals(locale + "_AS_CI") || + collationName.equals(locale + "_CI_AS")) { + caseSensitivity = CaseSensitivity.CI; + accentSensitivity = AccentSensitivity.AS; + } else if (collationName.equals(locale + "_AI") || + collationName.equals(locale + "_CS_AI") || + collationName.equals(locale + "_AI_CS")) { + caseSensitivity = CaseSensitivity.CS; + accentSensitivity = AccentSensitivity.AI; + } else if (collationName.equals(locale + "_AI_CI") || + collationName.equals(locale + "_CI_AI")) { + caseSensitivity = CaseSensitivity.CI; + accentSensitivity = AccentSensitivity.AI; + } else { + throw collationInvalidNameException(originalName); + } + + // Build collation ID from computed specifiers. 
+ collationId = SpecifierUtils.setSpecValue(collationId, + IMPLEMENTATION_PROVIDER_OFFSET, ImplementationProvider.ICU); + collationId = SpecifierUtils.setSpecValue(collationId, + CASE_SENSITIVITY_OFFSET, caseSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, + ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + return collationId; + } + } + + private static CollationSpecICU fromCollationId(int collationId) { + // Parse specifiers from collation ID. + int caseSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId, + CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); + int accentSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId, + ACCENT_SENSITIVITY_OFFSET, ACCENT_SENSITIVITY_MASK); + collationId = SpecifierUtils.removeSpec(collationId, + IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK); + collationId = SpecifierUtils.removeSpec(collationId, + CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); + collationId = SpecifierUtils.removeSpec(collationId, + ACCENT_SENSITIVITY_OFFSET, ACCENT_SENSITIVITY_MASK); + // Locale ID remains after removing all other specifiers. + int localeId = collationId; + // Verify locale ID is valid against `ICULocaleNames` array. + assert (localeId < ICULocaleNames.length); + CaseSensitivity caseSensitivity = CaseSensitivity.values()[caseSensitivityOrdinal]; + AccentSensitivity accentSensitivity = AccentSensitivity.values()[accentSensitivityOrdinal]; + String locale = ICULocaleNames[localeId]; + return new CollationSpecICU(locale, caseSensitivity, accentSensitivity); + } + + @Override + protected Collation buildCollation() { + ULocale.Builder builder = new ULocale.Builder(); + builder.setLocale(ICULocaleMap.get(locale)); + // Compute unicode locale keyword for all combinations of case/accent sensitivity. 
+ if (caseSensitivity == CaseSensitivity.CS && + accentSensitivity == AccentSensitivity.AS) { + builder.setUnicodeLocaleKeyword("ks", "level3"); + } else if (caseSensitivity == CaseSensitivity.CS && + accentSensitivity == AccentSensitivity.AI) { + builder + .setUnicodeLocaleKeyword("ks", "level1") + .setUnicodeLocaleKeyword("kc", "true"); + } else if (caseSensitivity == CaseSensitivity.CI && + accentSensitivity == AccentSensitivity.AS) { + builder.setUnicodeLocaleKeyword("ks", "level2"); + } else if (caseSensitivity == CaseSensitivity.CI && + accentSensitivity == AccentSensitivity.AI) { + builder.setUnicodeLocaleKeyword("ks", "level1"); + } + ULocale resultLocale = builder.build(); + Collator collator = Collator.getInstance(resultLocale); + // Freeze ICU collator to ensure thread safety. + collator.freeze(); + return new Collation( + collationName(), + PROVIDER_ICU, + collator, + (s1, s2) -> collator.compare(s1.toString(), s2.toString()), + ICU_COLLATOR_VERSION, + s -> (long) collator.getCollationKey(s.toString()).hashCode(), + /* supportsBinaryEquality = */ collationId == UNICODE_COLLATION_ID, + /* supportsBinaryOrdering = */ false, + /* supportsLowercaseEquality = */ false); + } + + /** + * Compute normalized collation name. Components of collation name are given in order: + * - Locale name + * - Optional case sensitivity when non-default preceded by underscore + * - Optional accent sensitivity when non-default preceded by underscore + * Examples: en, en_USA_CI_AI, sr_Cyrl_SRB_AI. 
+ */ + private String collationName() { + StringBuilder builder = new StringBuilder(); + builder.append(locale); + if (caseSensitivity != CaseSensitivity.CS) { + builder.append('_'); + builder.append(caseSensitivity.toString()); + } + if (accentSensitivity != AccentSensitivity.AS) { + builder.append('_'); + builder.append(accentSensitivity.toString()); + } + return builder.toString(); + } + } + + /** + * Utility class for manipulating conversions between collation IDs and specifier enums/locale + * IDs. Scope bitwise operations here to avoid confusion. + */ + private static class SpecifierUtils { + private static int getSpecValue(int collationId, int offset, int mask) { + return (collationId >> offset) & mask; + } + + private static int removeSpec(int collationId, int offset, int mask) { + return collationId & ~(mask << offset); + } + + private static int setSpecValue(int collationId, int offset, Enum spec) { + return collationId | (spec.ordinal() << offset); + } } /** Returns the collation identifier. */ @@ -201,75 +722,20 @@ public final class CollationFactory { } } - private static final Collation[] collationTable = new Collation[4]; - private static final HashMap<String, Integer> collationNameToIdMap = new HashMap<>(); - - public static final int UTF8_BINARY_COLLATION_ID = 0; - public static final int UTF8_BINARY_LCASE_COLLATION_ID = 1; - public static final String PROVIDER_SPARK = "spark"; public static final String PROVIDER_ICU = "icu"; public static final List<String> SUPPORTED_PROVIDERS = List.of(PROVIDER_SPARK, PROVIDER_ICU); - static { - // Binary comparison. This is the default collation. - // No custom comparators will be used for this collation. - // Instead, we rely on byte for byte comparison. - collationTable[0] = new Collation( - "UTF8_BINARY", - PROVIDER_SPARK, - null, - UTF8String::binaryCompare, - "1.0", - s -> (long)s.hashCode(), - true, - true, - false); - - // Case-insensitive UTF8 binary collation. 
- // TODO: Do in place comparisons instead of creating new strings. - collationTable[1] = new Collation( - "UTF8_BINARY_LCASE", - PROVIDER_SPARK, - null, - UTF8String::compareLowerCase, - "1.0", - (s) -> (long)s.toLowerCase().hashCode(), - false, - false, - true); - - // UNICODE case sensitive comparison (ROOT locale, in ICU). - collationTable[2] = new Collation( - "UNICODE", - PROVIDER_ICU, - Collator.getInstance(ULocale.ROOT), - "153.120.0.0", - true, - false, - false - ); - - collationTable[2].collator.setStrength(Collator.TERTIARY); - collationTable[2].collator.freeze(); - - // UNICODE case-insensitive comparison (ROOT locale, in ICU + Secondary strength). - collationTable[3] = new Collation( - "UNICODE_CI", - PROVIDER_ICU, - Collator.getInstance(ULocale.ROOT), - "153.120.0.0", - false, - false, - false - ); - collationTable[3].collator.setStrength(Collator.SECONDARY); - collationTable[3].collator.freeze(); - - for (int i = 0; i < collationTable.length; i++) { - collationNameToIdMap.put(collationTable[i].collationName, i); - } - } + public static final int UTF8_BINARY_COLLATION_ID = + Collation.CollationSpecUTF8Binary.UTF8_BINARY_COLLATION_ID; + public static final int UTF8_BINARY_LCASE_COLLATION_ID = + Collation.CollationSpecUTF8Binary.UTF8_BINARY_LCASE_COLLATION_ID; + public static final int UNICODE_COLLATION_ID = + Collation.CollationSpecICU.UNICODE_COLLATION_ID; + public static final int UNICODE_CI_COLLATION_ID = + Collation.CollationSpecICU.UNICODE_CI_COLLATION_ID; + public static final int INDETERMINATE_COLLATION_ID = + Collation.CollationSpec.INDETERMINATE_COLLATION_ID; /** * Returns a StringSearch object for the given pattern and target strings, under collation @@ -297,23 +763,6 @@ public final class CollationFactory { return new StringSearch(patternString, target, (RuleBasedCollator) collator); } - /** - * Returns if the given collationName is valid one. 
- */ - public static boolean isValidCollation(String collationName) { - return collationNameToIdMap.containsKey(collationName.toUpperCase()); - } - - /** - * Returns closest valid name to collationName - */ - public static String getClosestCollation(String collationName) { - Collation suggestion = Collections.min(List.of(collationTable), Comparator.comparingInt( - c -> UTF8String.fromString(c.collationName).levenshteinDistance( - UTF8String.fromString(collationName.toUpperCase())))); - return suggestion.collationName; - } - /** * Returns a collation-unaware StringSearch object for the given pattern and target strings. * While this object does not respect collation, it can be used to find occurrences of the pattern @@ -326,24 +775,10 @@ public final class CollationFactory { } /** - * Returns the collation id for the given collation name. + * Returns the collation ID for the given collation name. */ public static int collationNameToId(String collationName) throws SparkException { - String normalizedName = collationName.toUpperCase(); - if (collationNameToIdMap.containsKey(normalizedName)) { - return collationNameToIdMap.get(normalizedName); - } else { - Collation suggestion = Collections.min(List.of(collationTable), Comparator.comparingInt( - c -> UTF8String.fromString(c.collationName).levenshteinDistance( - UTF8String.fromString(normalizedName)))); - - Map<String, String> params = new HashMap<>(); - params.put("collationName", collationName); - params.put("proposal", suggestion.collationName); - - throw new SparkException( - "COLLATION_INVALID_NAME", SparkException.constructMessageParams(params), null); - } + return Collation.CollationSpec.collationNameToId(collationName); } public static void assertValidProvider(String provider) throws SparkException { @@ -359,12 +794,15 @@ public final class CollationFactory { } public static Collation fetchCollation(int collationId) { - return collationTable[collationId]; + return 
Collation.CollationSpec.fetchCollation(collationId); } public static Collation fetchCollation(String collationName) throws SparkException { - int collationId = collationNameToId(collationName); - return collationTable[collationId]; + return fetchCollation(collationNameToId(collationName)); + } + + public static String[] getICULocaleNames() { + return Collation.CollationSpecICU.ICULocaleNames; } public static UTF8String getCollationKey(UTF8String input, int collationId) { diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 768d26bf0e11..69104dea0e99 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -20,7 +20,10 @@ package org.apache.spark.unsafe.types import scala.collection.parallel.immutable.ParSeq import scala.jdk.CollectionConverters.MapHasAsScala +import com.ibm.icu.util.ULocale + import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.util.CollationFactory.fetchCollation // scalastyle:off import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.must.Matchers @@ -30,31 +33,95 @@ import org.apache.spark.unsafe.types.UTF8String.{fromString => toUTF8} class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ignore funsuite test("collationId stability") { - val utf8Binary = fetchCollation(0) + assert(INDETERMINATE_COLLATION_ID == -1) + + assert(UTF8_BINARY_COLLATION_ID == 0) + val utf8Binary = fetchCollation(UTF8_BINARY_COLLATION_ID) assert(utf8Binary.collationName == "UTF8_BINARY") assert(utf8Binary.supportsBinaryEquality) - val utf8BinaryLcase = fetchCollation(1) + assert(UTF8_BINARY_LCASE_COLLATION_ID == 1) + val utf8BinaryLcase = fetchCollation(UTF8_BINARY_LCASE_COLLATION_ID) assert(utf8BinaryLcase.collationName == 
"UTF8_BINARY_LCASE") assert(!utf8BinaryLcase.supportsBinaryEquality) - val unicode = fetchCollation(2) + assert(UNICODE_COLLATION_ID == (1 << 29)) + val unicode = fetchCollation(UNICODE_COLLATION_ID) assert(unicode.collationName == "UNICODE") - assert(unicode.supportsBinaryEquality); + assert(unicode.supportsBinaryEquality) - val unicodeCi = fetchCollation(3) + assert(UNICODE_CI_COLLATION_ID == ((1 << 29) | (1 << 17))) + val unicodeCi = fetchCollation(UNICODE_CI_COLLATION_ID) assert(unicodeCi.collationName == "UNICODE_CI") assert(!unicodeCi.supportsBinaryEquality) } - test("fetch invalid collation name") { - val error = intercept[SparkException] { - fetchCollation("UTF8_BS") + test("UTF8_BINARY and ICU root locale collation names") { + // Collation name already normalized. + Seq( + "UTF8_BINARY", + "UTF8_BINARY_LCASE", + "UNICODE", + "UNICODE_CI", + "UNICODE_AI", + "UNICODE_CI_AI" + ).foreach(collationName => { + val col = fetchCollation(collationName) + assert(col.collationName == collationName) + }) + // Collation name normalization. + Seq( + // ICU root locale. + ("UNICODE_CS", "UNICODE"), + ("UNICODE_CS_AS", "UNICODE"), + ("UNICODE_CI_AS", "UNICODE_CI"), + ("UNICODE_AI_CS", "UNICODE_AI"), + ("UNICODE_AI_CI", "UNICODE_CI_AI"), + // Randomized case collation names. 
+ ("utf8_binary", "UTF8_BINARY"), + ("UtF8_binARy_LcasE", "UTF8_BINARY_LCASE"), + ("unicode", "UNICODE"), + ("UnICoDe_cs_aI", "UNICODE_AI") + ).foreach{ + case (name, normalized) => + val col = fetchCollation(name) + assert(col.collationName == normalized) } + } + + test("fetch invalid UTF8_BINARY and ICU root locale collation names") { + Seq( + "UTF8_BINARY_CS", + "UTF8_BINARY_AS", + "UTF8_BINARY_CS_AS", + "UTF8_BINARY_AS_CS", + "UTF8_BINARY_CI", + "UTF8_BINARY_AI", + "UTF8_BINARY_CI_AI", + "UTF8_BINARY_AI_CI", + "UTF8_BS", + "BINARY_UTF8", + "UTF8_BINARY_A", + "UNICODE_X", + "UNICODE_CI_X", + "UNICODE_LCASE_X", + "UTF8_UNICODE", + "UTF8_BINARY_UNICODE", + "CI_UNICODE", + "LCASE_UNICODE", + "UNICODE_UNSPECIFIED", + "UNICODE_CI_UNSPECIFIED", + "UNICODE_UNSPECIFIED_CI_UNSPECIFIED", + "UNICODE_INDETERMINATE", + "UNICODE_CI_INDETERMINATE" + ).foreach(collationName => { + val error = intercept[SparkException] { + fetchCollation(collationName) + } - assert(error.getErrorClass === "COLLATION_INVALID_NAME") - assert(error.getMessageParameters.asScala === - Map("proposal" -> "UTF8_BINARY", "collationName" -> "UTF8_BS")) + assert(error.getErrorClass === "COLLATION_INVALID_NAME") + assert(error.getMessageParameters.asScala === Map("collationName" -> collationName)) + }) } case class CollationTestCase[R](collationName: String, s1: String, s2: String, expectedResult: R) @@ -152,4 +219,238 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig } }) } + + test("test collation caching") { + Seq( + "UTF8_BINARY", + "UTF8_BINARY_LCASE", + "UNICODE", + "UNICODE_CI", + "UNICODE_AI", + "UNICODE_CI_AI", + "UNICODE_AI_CI" + ).foreach(collationId => { + val col1 = fetchCollation(collationId) + val col2 = fetchCollation(collationId) + assert(col1 eq col2) // Check for reference equality. + }) + } + + test("collations with ICU non-root localization") { + Seq( + // Language only. 
+ "en", + "en_CS", + "en_CI", + "en_AS", + "en_AI", + // Language + 3-letter country code. + "en_USA", + "en_USA_CS", + "en_USA_CI", + "en_USA_AS", + "en_USA_AI", + // Language + script code. + "sr_Cyrl", + "sr_Cyrl_CS", + "sr_Cyrl_CI", + "sr_Cyrl_AS", + "sr_Cyrl_AI", + // Language + script code + 3-letter country code. + "sr_Cyrl_SRB", + "sr_Cyrl_SRB_CS", + "sr_Cyrl_SRB_CI", + "sr_Cyrl_SRB_AS", + "sr_Cyrl_SRB_AI" + ).foreach(collationICU => { + val col = fetchCollation(collationICU) + assert(col.collator.getLocale(ULocale.VALID_LOCALE) != ULocale.ROOT) + }) + } + + test("invalid names of collations with ICU non-root localization") { + Seq( + "en_US", // Must use 3-letter country code + "enn", + "en_AAA", + "en_Something", + "en_Something_USA", + "en_LCASE", + "en_UCASE", + "en_CI_LCASE", + "en_CI_UCASE", + "en_CI_UNSPECIFIED", + "en_USA_UNSPECIFIED", + "en_USA_UNSPECIFIED_CI", + "en_INDETERMINATE", + "en_USA_INDETERMINATE", + "en_Latn_USA", // Use en_USA instead. + "en_Cyrl_USA", + "en_USA_AAA", + "sr_Cyrl_SRB_AAA", + // Invalid ordering of language, script and country code. + "USA_en", + "sr_SRB_Cyrl", + "SRB_sr", + "SRB_sr_Cyrl", + "SRB_Cyrl_sr", + "Cyrl_sr", + "Cyrl_sr_SRB", + "Cyrl_SRB_sr", + // Collation specifiers in the middle of locale. 
+ "CI_en", + "USA_CI_en", + "en_CI_USA", + "CI_sr_Cyrl_SRB", + "sr_CI_Cyrl_SRB", + "sr_Cyrl_CI_SRB", + "CI_Cyrl_sr", + "Cyrl_CI_sr", + "Cyrl_CI_sr_SRB", + "Cyrl_sr_CI_SRB" + ).foreach(collationName => { + val error = intercept[SparkException] { + fetchCollation(collationName) + } + + assert(error.getErrorClass === "COLLATION_INVALID_NAME") + assert(error.getMessageParameters.asScala === Map("collationName" -> collationName)) + }) + } + + test("collations name normalization for ICU non-root localization") { + Seq( + ("en_USA", "en_USA"), + ("en_CS", "en"), + ("en_AS", "en"), + ("en_CS_AS", "en"), + ("en_AS_CS", "en"), + ("en_CI", "en_CI"), + ("en_AI", "en_AI"), + ("en_AI_CI", "en_CI_AI"), + ("en_CI_AI", "en_CI_AI"), + ("en_CS_AI", "en_AI"), + ("en_AI_CS", "en_AI"), + ("en_CI_AS", "en_CI"), + ("en_AS_CI", "en_CI"), + ("en_USA_AI_CI", "en_USA_CI_AI"), + // Randomized case. + ("EN_USA", "en_USA"), + ("SR_CYRL", "sr_Cyrl"), + ("sr_cyrl_srb", "sr_Cyrl_SRB"), + ("sR_cYRl_sRb", "sr_Cyrl_SRB") + ).foreach { + case (name, normalized) => + val col = fetchCollation(name) + assert(col.collationName == normalized) + } + } + + test("invalid collationId") { + val badCollationIds = Seq( + INDETERMINATE_COLLATION_ID, // Indeterminate collation. + 1 << 30, // User-defined collation range. + (1 << 30) | 1, // User-defined collation range. + (1 << 30) | (1 << 29), // User-defined collation range. + 1 << 1, // UTF8_BINARY mandatory zero bit 1 breach. + 1 << 2, // UTF8_BINARY mandatory zero bit 2 breach. + 1 << 3, // UTF8_BINARY mandatory zero bit 3 breach. + 1 << 4, // UTF8_BINARY mandatory zero bit 4 breach. + 1 << 5, // UTF8_BINARY mandatory zero bit 5 breach. + 1 << 6, // UTF8_BINARY mandatory zero bit 6 breach. + 1 << 7, // UTF8_BINARY mandatory zero bit 7 breach. + 1 << 8, // UTF8_BINARY mandatory zero bit 8 breach. + 1 << 9, // UTF8_BINARY mandatory zero bit 9 breach. + 1 << 10, // UTF8_BINARY mandatory zero bit 10 breach. + 1 << 11, // UTF8_BINARY mandatory zero bit 11 breach. 
+ 1 << 12, // UTF8_BINARY mandatory zero bit 12 breach. + 1 << 13, // UTF8_BINARY mandatory zero bit 13 breach. + 1 << 14, // UTF8_BINARY mandatory zero bit 14 breach. + 1 << 15, // UTF8_BINARY mandatory zero bit 15 breach. + 1 << 16, // UTF8_BINARY mandatory zero bit 16 breach. + 1 << 17, // UTF8_BINARY mandatory zero bit 17 breach. + 1 << 18, // UTF8_BINARY mandatory zero bit 18 breach. + 1 << 19, // UTF8_BINARY mandatory zero bit 19 breach. + 1 << 20, // UTF8_BINARY mandatory zero bit 20 breach. + 1 << 23, // UTF8_BINARY mandatory zero bit 23 breach. + 1 << 24, // UTF8_BINARY mandatory zero bit 24 breach. + 1 << 25, // UTF8_BINARY mandatory zero bit 25 breach. + 1 << 26, // UTF8_BINARY mandatory zero bit 26 breach. + 1 << 27, // UTF8_BINARY mandatory zero bit 27 breach. + 1 << 28, // UTF8_BINARY mandatory zero bit 28 breach. + (1 << 29) | (1 << 12), // ICU mandatory zero bit 12 breach. + (1 << 29) | (1 << 13), // ICU mandatory zero bit 13 breach. + (1 << 29) | (1 << 14), // ICU mandatory zero bit 14 breach. + (1 << 29) | (1 << 15), // ICU mandatory zero bit 15 breach. + (1 << 29) | (1 << 18), // ICU mandatory zero bit 18 breach. + (1 << 29) | (1 << 19), // ICU mandatory zero bit 19 breach. + (1 << 29) | (1 << 20), // ICU mandatory zero bit 20 breach. + (1 << 29) | (1 << 21), // ICU mandatory zero bit 21 breach. + (1 << 29) | (1 << 22), // ICU mandatory zero bit 22 breach. + (1 << 29) | (1 << 23), // ICU mandatory zero bit 23 breach. + (1 << 29) | (1 << 24), // ICU mandatory zero bit 24 breach. + (1 << 29) | (1 << 25), // ICU mandatory zero bit 25 breach. + (1 << 29) | (1 << 26), // ICU mandatory zero bit 26 breach. + (1 << 29) | (1 << 27), // ICU mandatory zero bit 27 breach. + (1 << 29) | (1 << 28), // ICU mandatory zero bit 28 breach. + (1 << 29) | 0xFFFF // ICU with invalid locale id. + ) + badCollationIds.foreach(collationId => { + // Assumptions about collation id will break and assert statement will fail. 
+ intercept[AssertionError](fetchCollation(collationId)) + }) + } + + test("repeated and/or incompatible specifiers in collation name") { + Seq( + "UTF8_BINARY_LCASE_LCASE", + "UNICODE_CS_CS", + "UNICODE_CI_CI", + "UNICODE_CI_CS", + "UNICODE_CS_CI", + "UNICODE_AS_AS", + "UNICODE_AI_AI", + "UNICODE_AS_AI", + "UNICODE_AI_AS", + "UNICODE_AS_CS_AI", + "UNICODE_CS_AI_CI", + "UNICODE_CS_AS_CI_AI" + ).foreach(collationName => { + val error = intercept[SparkException] { + fetchCollation(collationName) + } + + assert(error.getErrorClass === "COLLATION_INVALID_NAME") + assert(error.getMessageParameters.asScala === Map("collationName" -> collationName)) + }) + } + + test("basic ICU collator checks") { + Seq( + CollationTestCase("UNICODE_CI", "a", "A", true), + CollationTestCase("UNICODE_CI", "a", "å", false), + CollationTestCase("UNICODE_CI", "a", "Å", false), + CollationTestCase("UNICODE_AI", "a", "A", false), + CollationTestCase("UNICODE_AI", "a", "å", true), + CollationTestCase("UNICODE_AI", "a", "Å", false), + CollationTestCase("UNICODE_CI_AI", "a", "A", true), + CollationTestCase("UNICODE_CI_AI", "a", "å", true), + CollationTestCase("UNICODE_CI_AI", "a", "Å", true) + ).foreach(testCase => { + val collation = fetchCollation(testCase.collationName) + assert(collation.equalsFunction(toUTF8(testCase.s1), toUTF8(testCase.s2)) == + testCase.expectedResult) + }) + Seq( + CollationTestCase("en", "a", "A", -1), + CollationTestCase("en_CI", "a", "A", 0), + CollationTestCase("en_AI", "a", "å", 0), + CollationTestCase("sv", "Kypper", "Köpfe", -1), + CollationTestCase("de", "Kypper", "Köpfe", 1) + ).foreach(testCase => { + val collation = fetchCollation(testCase.collationName) + val result = collation.comparator.compare(toUTF8(testCase.s1), toUTF8(testCase.s2)) + assert(Integer.signum(result) == testCase.expectedResult) + }) + } } diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 
883c51bffade..b19b05859f78 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -469,7 +469,7 @@ }, "COLLATION_INVALID_NAME" : { "message" : [ - "The value <collationName> does not represent a correct collation name. Suggested valid collation name: [<proposal>]." + "The value <collationName> does not represent a correct collation name." ], "sqlState" : "42704" }, @@ -1921,7 +1921,7 @@ "subClass" : { "DEFAULT_COLLATION" : { "message" : [ - "Cannot resolve the given default collation. Did you mean '<proposal>'?" + "Cannot resolve the given default collation." ] }, "TIME_ZONE" : { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 49b1a5312fda..e0ad8f7078ca 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.{functions => fn} import org.apache.spark.sql.avro.{functions => avroFn} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -699,7 +700,8 @@ class PlanGenerationTestSuite } test("select collated string") { - val schema = StructType(StructField("s", StringType(1)) :: Nil) + val schema = StructType( + StructField("s", StringType(CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID)) :: Nil) createLocalRelation(schema.catalogString).select("s") } diff --git a/connector/connect/common/src/main/protobuf/spark/connect/types.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/types.proto index 48f7385330c8..4f768f201575 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/types.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/types.proto @@ -101,7 +101,7 @@ message DataType { message String { uint32 type_variation_reference = 1; - uint32 collation_id = 2; + string collation = 2; } message Binary { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index 1f580a0ffc0a..f63692717947 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.common import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkClassUtils @@ -80,7 +81,7 @@ object DataTypeProtoConverter { } private def toCatalystStringType(t: proto.DataType.String): StringType = - StringType(t.getCollationId) + StringType(if (t.getCollation.nonEmpty) t.getCollation else "UTF8_BINARY") private def toCatalystYearMonthIntervalType(t: proto.DataType.YearMonthInterval) = { (t.hasStartField, t.hasEndField) match { @@ -177,7 +178,11 @@ object DataTypeProtoConverter { case s: StringType => proto.DataType .newBuilder() - .setString(proto.DataType.String.newBuilder().setCollationId(s.collationId).build()) + .setString( + proto.DataType.String + .newBuilder() + .setCollation(CollationFactory.fetchCollation(s.collationId).collationName) + .build()) .build() case CharType(length) => diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json index 33f6007ec68a..e4b31258f984 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.json @@ -18,7 +18,7 @@ "name": "c1", "dataType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "nullable": true diff --git a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.proto.bin index da4ad9bf9a4e..c39243a10a8e 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/csv_from_dataset.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json index adf8cabd97b1..2a5a0ddd15f8 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json @@ -305,7 +305,7 @@ "array": { "elementType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "elements": [{ @@ -324,7 +324,7 @@ "array": { "elementType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "elements": [{ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin index d8b4407f6cfa..359ddd61d8b7 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin 
and b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.json index 1e651f0455c7..aaf3a91c4fe1 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.json @@ -200,7 +200,7 @@ "map": { "keyType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueType": { @@ -228,7 +228,7 @@ "name": "_1", "dataType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "nullable": true @@ -404,7 +404,7 @@ "map": { "keyType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueType": { @@ -417,7 +417,7 @@ "map": { "keyType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueType": { @@ -439,7 +439,7 @@ "map": { "keyType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueType": { @@ -461,7 +461,7 @@ "map": { "keyType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueType": { @@ -493,7 +493,7 @@ "map": { "keyType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueType": { @@ -511,7 +511,7 @@ "map": { "keyType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueType": { @@ -533,7 +533,7 @@ "map": { "keyType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueType": { @@ -576,7 +576,7 @@ "map": { "keyType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueType": { @@ -594,7 +594,7 @@ "name": "_1", "dataType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "nullable": true @@ -608,7 +608,7 @@ }, "valueType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, 
"valueContainsNull": true @@ -640,7 +640,7 @@ "map": { "keyType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueType": { @@ -666,7 +666,7 @@ "name": "_1", "dataType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "nullable": true @@ -680,7 +680,7 @@ }, "valueType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "valueContainsNull": true @@ -700,7 +700,7 @@ }, "valueType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "keys": [{ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin index b3f61830bee0..71640717c12e 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ diff --git a/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.json b/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.json index 537c218952a4..f29245374e6e 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.json @@ -18,7 +18,7 @@ "name": "c1", "dataType": { "string": { - "collationId": 0 + "collation": "UTF8_BINARY" } }, "nullable": true diff --git a/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.proto.bin index 297ab2bf0262..1ce2e676ce30 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/json_from_dataset.proto.bin differ diff --git 
a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py index 65e5860b5dc6..1022605fb160 100644 --- a/python/pyspark/sql/connect/proto/types_pb2.py +++ b/python/pyspark/sql/connect/proto/types_pb2.py @@ -29,7 +29,7 @@ _sym_db = _symbol_database.Default() DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b"\n\x19spark/connect/types.proto\x12\rspark.connect\"\xec!\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x [...] + b"\n\x19spark/connect/types.proto\x12\rspark.connect\"\xe7!\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x [...] 
) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -42,7 +42,7 @@ if _descriptor._USE_C_DESCRIPTORS == False: b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated" ) _DATATYPE._serialized_start = 45 - _DATATYPE._serialized_end = 4377 + _DATATYPE._serialized_end = 4372 _DATATYPE_BOOLEAN._serialized_start = 1595 _DATATYPE_BOOLEAN._serialized_end = 1662 _DATATYPE_BYTE._serialized_start = 1664 @@ -58,41 +58,41 @@ if _descriptor._USE_C_DESCRIPTORS == False: _DATATYPE_DOUBLE._serialized_start = 1999 _DATATYPE_DOUBLE._serialized_end = 2065 _DATATYPE_STRING._serialized_start = 2067 - _DATATYPE_STRING._serialized_end = 2168 - _DATATYPE_BINARY._serialized_start = 2170 - _DATATYPE_BINARY._serialized_end = 2236 - _DATATYPE_NULL._serialized_start = 2238 - _DATATYPE_NULL._serialized_end = 2302 - _DATATYPE_TIMESTAMP._serialized_start = 2304 - _DATATYPE_TIMESTAMP._serialized_end = 2373 - _DATATYPE_DATE._serialized_start = 2375 - _DATATYPE_DATE._serialized_end = 2439 - _DATATYPE_TIMESTAMPNTZ._serialized_start = 2441 - _DATATYPE_TIMESTAMPNTZ._serialized_end = 2513 - _DATATYPE_CALENDARINTERVAL._serialized_start = 2515 - _DATATYPE_CALENDARINTERVAL._serialized_end = 2591 - _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2594 - _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2773 - _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2776 - _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2953 - _DATATYPE_CHAR._serialized_start = 2955 - _DATATYPE_CHAR._serialized_end = 3043 - _DATATYPE_VARCHAR._serialized_start = 3045 - _DATATYPE_VARCHAR._serialized_end = 3136 - _DATATYPE_DECIMAL._serialized_start = 3139 - _DATATYPE_DECIMAL._serialized_end = 3292 - _DATATYPE_STRUCTFIELD._serialized_start = 3295 - _DATATYPE_STRUCTFIELD._serialized_end = 3456 - _DATATYPE_STRUCT._serialized_start = 3458 - _DATATYPE_STRUCT._serialized_end = 3585 - _DATATYPE_ARRAY._serialized_start = 3588 - _DATATYPE_ARRAY._serialized_end = 3750 - _DATATYPE_MAP._serialized_start = 3753 - 
_DATATYPE_MAP._serialized_end = 3972 - _DATATYPE_VARIANT._serialized_start = 3974 - _DATATYPE_VARIANT._serialized_end = 4041 - _DATATYPE_UDT._serialized_start = 4044 - _DATATYPE_UDT._serialized_end = 4315 - _DATATYPE_UNPARSED._serialized_start = 4317 - _DATATYPE_UNPARSED._serialized_end = 4369 + _DATATYPE_STRING._serialized_end = 2163 + _DATATYPE_BINARY._serialized_start = 2165 + _DATATYPE_BINARY._serialized_end = 2231 + _DATATYPE_NULL._serialized_start = 2233 + _DATATYPE_NULL._serialized_end = 2297 + _DATATYPE_TIMESTAMP._serialized_start = 2299 + _DATATYPE_TIMESTAMP._serialized_end = 2368 + _DATATYPE_DATE._serialized_start = 2370 + _DATATYPE_DATE._serialized_end = 2434 + _DATATYPE_TIMESTAMPNTZ._serialized_start = 2436 + _DATATYPE_TIMESTAMPNTZ._serialized_end = 2508 + _DATATYPE_CALENDARINTERVAL._serialized_start = 2510 + _DATATYPE_CALENDARINTERVAL._serialized_end = 2586 + _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2589 + _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2768 + _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2771 + _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2948 + _DATATYPE_CHAR._serialized_start = 2950 + _DATATYPE_CHAR._serialized_end = 3038 + _DATATYPE_VARCHAR._serialized_start = 3040 + _DATATYPE_VARCHAR._serialized_end = 3131 + _DATATYPE_DECIMAL._serialized_start = 3134 + _DATATYPE_DECIMAL._serialized_end = 3287 + _DATATYPE_STRUCTFIELD._serialized_start = 3290 + _DATATYPE_STRUCTFIELD._serialized_end = 3451 + _DATATYPE_STRUCT._serialized_start = 3453 + _DATATYPE_STRUCT._serialized_end = 3580 + _DATATYPE_ARRAY._serialized_start = 3583 + _DATATYPE_ARRAY._serialized_end = 3745 + _DATATYPE_MAP._serialized_start = 3748 + _DATATYPE_MAP._serialized_end = 3967 + _DATATYPE_VARIANT._serialized_start = 3969 + _DATATYPE_VARIANT._serialized_end = 4036 + _DATATYPE_UDT._serialized_start = 4039 + _DATATYPE_UDT._serialized_end = 4310 + _DATATYPE_UNPARSED._serialized_start = 4312 + _DATATYPE_UNPARSED._serialized_end = 4364 # @@protoc_insertion_point(module_scope) 
diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi index e6b34d3485c2..b37621104537 100644 --- a/python/pyspark/sql/connect/proto/types_pb2.pyi +++ b/python/pyspark/sql/connect/proto/types_pb2.pyi @@ -178,22 +178,19 @@ class DataType(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int - COLLATION_ID_FIELD_NUMBER: builtins.int + COLLATION_FIELD_NUMBER: builtins.int type_variation_reference: builtins.int - collation_id: builtins.int + collation: builtins.str def __init__( self, *, type_variation_reference: builtins.int = ..., - collation_id: builtins.int = ..., + collation: builtins.str = ..., ) -> None: ... def ClearField( self, field_name: typing_extensions.Literal[ - "collation_id", - b"collation_id", - "type_variation_reference", - b"type_variation_reference", + "collation", b"collation", "type_variation_reference", b"type_variation_reference" ], ) -> None: ... 
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py index 351fa0165965..885ce62e7db6 100644 --- a/python/pyspark/sql/connect/types.py +++ b/python/pyspark/sql/connect/types.py @@ -129,7 +129,7 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType: if isinstance(data_type, NullType): ret.null.CopyFrom(pb2.DataType.NULL()) elif isinstance(data_type, StringType): - ret.string.collation_id = data_type.collationId + ret.string.collation = data_type.collation elif isinstance(data_type, BooleanType): ret.boolean.CopyFrom(pb2.DataType.Boolean()) elif isinstance(data_type, BinaryType): @@ -229,7 +229,8 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType: s = schema.decimal.scale if schema.decimal.HasField("scale") else 0 return DecimalType(precision=p, scale=s) elif schema.HasField("string"): - return StringType.fromCollationId(schema.string.collation_id) + collation = schema.string.collation if schema.string.collation != "" else "UTF8_BINARY" + return StringType(collation) elif schema.HasField("char"): return CharType(schema.char.length) elif schema.HasField("var_char"): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 563c63f5dfb1..c72ff72ce426 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -280,26 +280,13 @@ class StringType(AtomicType): name of the collation, default is UTF8_BINARY. 
""" - collationNames = ["UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI"] providerSpark = "spark" providerICU = "icu" providers = [providerSpark, providerICU] - def __init__(self, collation: Optional[str] = None): + def __init__(self, collation: str = "UTF8_BINARY"): self.typeName = self._type_name # type: ignore[method-assign] - self.collationId = 0 if collation is None else self.collationNameToId(collation) - - @classmethod - def fromCollationId(self, collationId: int) -> "StringType": - return StringType(StringType.collationNames[collationId]) - - @classmethod - def collationIdToName(cls, collationId: int) -> str: - return StringType.collationNames[collationId] - - @classmethod - def collationNameToId(cls, collationName: str) -> int: - return StringType.collationNames.index(collationName) + self.collation = collation @classmethod def collationProvider(cls, collationName: str) -> str: @@ -312,7 +299,7 @@ class StringType(AtomicType): if self.isUTF8BinaryCollation(): return "string" - return f"string collate ${self.collationIdToName(self.collationId)}" + return f"string collate ${self.collation}" # For backwards compatibility and compatibility with other readers all string types # are serialized in json as regular strings and the collation info is written to @@ -322,13 +309,11 @@ class StringType(AtomicType): def __repr__(self) -> str: return ( - "StringType('%s')" % StringType.collationNames[self.collationId] - if self.collationId != 0 - else "StringType()" + "StringType()" if self.isUTF8BinaryCollation() else "StringType('%s')" % self.collation ) def isUTF8BinaryCollation(self) -> bool: - return self.collationId == 0 + return self.collation == "UTF8_BINARY" class CharType(AtomicType): @@ -1046,7 +1031,7 @@ class StructField(DataType): def schemaCollationValue(self, dt: DataType) -> str: assert isinstance(dt, StringType) - collationName = StringType.collationIdToName(dt.collationId) + collationName = dt.collation provider = 
StringType.collationProvider(collationName) return f"{provider}.{collationName}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 06e0c6eda589..f6f5b23b7f10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -772,12 +772,17 @@ object SQLConf { " produced by a builtin function such as to_char or CAST") .version("4.0.0") .stringConf - .checkValue(CollationFactory.isValidCollation, + .checkValue( + collationName => { + try { + CollationFactory.fetchCollation(collationName) + true + } catch { + case e: SparkException if e.getErrorClass == "COLLATION_INVALID_NAME" => false + } + }, "DEFAULT_COLLATION", - name => - Map( - "proposal" -> CollationFactory.getClosestCollation(name) - )) + _ => Map()) .createWithDefault("UTF8_BINARY") val FETCH_SHUFFLE_BLOCKS_IN_BATCH = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala index 537bac9aae9b..c3495a0c112c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala @@ -62,7 +62,7 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { exception = intercept[SparkException] { Collate(Literal("abc"), "UTF8_BS") }, errorClass = "COLLATION_INVALID_NAME", sqlState = "42704", - parameters = Map("proposal" -> "UTF8_BINARY", "collationName" -> "UTF8_BS")) + parameters = Map("collationName" -> "UTF8_BS")) } test("collation on non-explicit default collation") { @@ -71,7 +71,8 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { 
test("collation on explicitly collated string") { checkEvaluation( - Collation(Literal.create("abc", StringType(1))).replacement, + Collation(Literal.create("abc", + StringType(CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID))).replacement, "UTF8_BINARY_LCASE") checkEvaluation( Collation(Collate(Literal("abc"), "UTF8_BINARY_LCASE")).replacement, @@ -161,4 +162,32 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ArrayExcept(left, right), out) } } + + test("collation name normalization in collation expression") { + Seq( + ("en_USA", "en_USA"), + ("en_CS", "en"), + ("en_AS", "en"), + ("en_CS_AS", "en"), + ("en_AS_CS", "en"), + ("en_CI", "en_CI"), + ("en_AI", "en_AI"), + ("en_AI_CI", "en_CI_AI"), + ("en_CI_AI", "en_CI_AI"), + ("en_CS_AI", "en_AI"), + ("en_AI_CS", "en_AI"), + ("en_CI_AS", "en_CI"), + ("en_AS_CI", "en_CI"), + ("en_USA_AI_CI", "en_USA_CI_AI"), + // randomized case + ("EN_USA", "en_USA"), + ("SR_CYRL", "sr_Cyrl"), + ("sr_cyrl_srb", "sr_Cyrl_SRB"), + ("sR_cYRl_sRb", "sr_Cyrl_SRB") + ).foreach { + case (collation, normalized) => + checkEvaluation(Collation(Literal.create("abc", StringType(collation))).replacement, + normalized) + } + } } diff --git a/sql/core/src/test/resources/collations/ICU-collations-map.md b/sql/core/src/test/resources/collations/ICU-collations-map.md new file mode 100644 index 000000000000..598c3c4b4024 --- /dev/null +++ b/sql/core/src/test/resources/collations/ICU-collations-map.md @@ -0,0 +1,143 @@ +<!-- Automatically generated by ICUCollationsMapSuite --> +## ICU locale ids to name map +| Locale id | Locale name | +| --------- | ----------- | +| 0 | UNICODE | +| 1 | af | +| 2 | am | +| 3 | ar | +| 4 | ar_SAU | +| 5 | as | +| 6 | az | +| 7 | be | +| 8 | bg | +| 9 | bn | +| 10 | bo | +| 11 | br | +| 12 | bs | +| 13 | bs_Cyrl | +| 14 | ca | +| 15 | ceb | +| 16 | chr | +| 17 | cs | +| 18 | cy | +| 19 | da | +| 20 | de | +| 21 | de_AUT | +| 22 | dsb | +| 23 | dz | +| 24 | ee | +| 25 | el 
| +| 26 | en | +| 27 | en_USA | +| 28 | eo | +| 29 | es | +| 30 | et | +| 31 | fa | +| 32 | fa_AFG | +| 33 | ff | +| 34 | ff_Adlm | +| 35 | fi | +| 36 | fil | +| 37 | fo | +| 38 | fr | +| 39 | fr_CAN | +| 40 | fy | +| 41 | ga | +| 42 | gl | +| 43 | gu | +| 44 | ha | +| 45 | haw | +| 46 | he | +| 47 | he_ISR | +| 48 | hi | +| 49 | hr | +| 50 | hsb | +| 51 | hu | +| 52 | hy | +| 53 | id | +| 54 | id_IDN | +| 55 | ig | +| 56 | is | +| 57 | it | +| 58 | ja | +| 59 | ka | +| 60 | kk | +| 61 | kl | +| 62 | km | +| 63 | kn | +| 64 | ko | +| 65 | kok | +| 66 | ku | +| 67 | ky | +| 68 | lb | +| 69 | lkt | +| 70 | ln | +| 71 | lo | +| 72 | lt | +| 73 | lv | +| 74 | mk | +| 75 | ml | +| 76 | mn | +| 77 | mr | +| 78 | ms | +| 79 | mt | +| 80 | my | +| 81 | nb | +| 82 | nb_NOR | +| 83 | ne | +| 84 | nl | +| 85 | nn | +| 86 | no | +| 87 | om | +| 88 | or | +| 89 | pa | +| 90 | pa_Guru | +| 91 | pa_Guru_IND | +| 92 | pl | +| 93 | ps | +| 94 | pt | +| 95 | ro | +| 96 | ru | +| 97 | sa | +| 98 | se | +| 99 | si | +| 100 | sk | +| 101 | sl | +| 102 | smn | +| 103 | sq | +| 104 | sr | +| 105 | sr_Cyrl | +| 106 | sr_Cyrl_BIH | +| 107 | sr_Cyrl_MNE | +| 108 | sr_Cyrl_SRB | +| 109 | sr_Latn | +| 110 | sr_Latn_BIH | +| 111 | sr_Latn_SRB | +| 112 | sv | +| 113 | sw | +| 114 | ta | +| 115 | te | +| 116 | th | +| 117 | tk | +| 118 | to | +| 119 | tr | +| 120 | ug | +| 121 | uk | +| 122 | ur | +| 123 | uz | +| 124 | vi | +| 125 | wae | +| 126 | wo | +| 127 | xh | +| 128 | yi | +| 129 | yo | +| 130 | zh | +| 131 | zh_Hans | +| 132 | zh_Hans_CHN | +| 133 | zh_Hans_SGP | +| 134 | zh_Hant | +| 135 | zh_Hant_HKG | +| 136 | zh_Hant_MAC | +| 137 | zh_Hant_TWN | +| 138 | zu | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index d242a60a17c1..9a1f4ed1f8e5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -312,3 +312,80 @@ select array_except(array('aaa' collate utf8_binary_lcase), array('AAA' collate -- !query analysis Project [array_except(array(collate(aaa, utf8_binary_lcase)), array(collate(AAA, utf8_binary_lcase))) AS array_except(array(collate(aaa)), array(collate(AAA)))#x] +- OneRowRelation + + +-- !query +select 'a' collate unicode < 'A' +-- !query analysis +Project [(collate(a, unicode) < cast(A as string collate UNICODE)) AS (collate(a) < A)#x] ++- OneRowRelation + + +-- !query +select 'a' collate unicode_ci = 'A' +-- !query analysis +Project [(collate(a, unicode_ci) = cast(A as string collate UNICODE_CI)) AS (collate(a) = A)#x] ++- OneRowRelation + + +-- !query +select 'a' collate unicode_ai = 'å' +-- !query analysis +Project [(collate(a, unicode_ai) = cast(å as string collate UNICODE_AI)) AS (collate(a) = å)#x] ++- OneRowRelation + + +-- !query +select 'a' collate unicode_ci_ai = 'Å' +-- !query analysis +Project [(collate(a, unicode_ci_ai) = cast(Å as string collate UNICODE_CI_AI)) AS (collate(a) = Å)#x] ++- OneRowRelation + + +-- !query +select 'a' collate en < 'A' +-- !query analysis +Project [(collate(a, en) < cast(A as string collate en)) AS (collate(a) < A)#x] ++- OneRowRelation + + +-- !query +select 'a' collate en_ci = 'A' +-- !query analysis +Project [(collate(a, en_ci) = cast(A as string collate en_CI)) AS (collate(a) = A)#x] ++- OneRowRelation + + +-- !query +select 'a' collate en_ai = 'å' +-- !query analysis +Project [(collate(a, en_ai) = cast(å as string collate en_AI)) AS (collate(a) = å)#x] ++- OneRowRelation + + +-- !query +select 'a' collate en_ci_ai = 'Å' +-- !query analysis +Project [(collate(a, en_ci_ai) = cast(Å as string collate en_CI_AI)) AS (collate(a) = Å)#x] ++- OneRowRelation + + +-- !query +select 'Kypper' collate sv < 'Köpfe' +-- !query analysis +Project [(collate(Kypper, sv) < cast(Köpfe as string collate sv)) AS (collate(Kypper) < Köpfe)#x] ++- 
OneRowRelation + + +-- !query +select 'Kypper' collate de > 'Köpfe' +-- !query analysis +Project [(collate(Kypper, de) > cast(Köpfe as string collate de)) AS (collate(Kypper) > Köpfe)#x] ++- OneRowRelation + + +-- !query +select 'I' collate tr_ci = 'ı' +-- !query analysis +Project [(collate(I, tr_ci) = cast(ı as string collate tr_CI)) AS (collate(I) = ı)#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql index 619eb4470e9a..6bb0a0163443 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql @@ -77,3 +77,16 @@ select array_distinct(array('aaa' collate utf8_binary_lcase, 'AAA' collate utf8_ select array_union(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)); select array_intersect(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)); select array_except(array('aaa' collate utf8_binary_lcase), array('AAA' collate utf8_binary_lcase)); + +-- ICU collations (all statements return true) +select 'a' collate unicode < 'A'; +select 'a' collate unicode_ci = 'A'; +select 'a' collate unicode_ai = 'å'; +select 'a' collate unicode_ci_ai = 'Å'; +select 'a' collate en < 'A'; +select 'a' collate en_ci = 'A'; +select 'a' collate en_ai = 'å'; +select 'a' collate en_ci_ai = 'Å'; +select 'Kypper' collate sv < 'Köpfe'; +select 'Kypper' collate de > 'Köpfe'; +select 'I' collate tr_ci = 'ı'; diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index 4485191ba1f3..96c875306d35 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -339,3 +339,91 @@ select array_except(array('aaa' collate utf8_binary_lcase), array('AAA' collate 
struct<array_except(array(collate(aaa)), array(collate(AAA))):array<string collate UTF8_BINARY_LCASE>> -- !query output [] + + +-- !query +select 'a' collate unicode < 'A' +-- !query schema +struct<(collate(a) < A):boolean> +-- !query output +true + + +-- !query +select 'a' collate unicode_ci = 'A' +-- !query schema +struct<(collate(a) = A):boolean> +-- !query output +true + + +-- !query +select 'a' collate unicode_ai = 'å' +-- !query schema +struct<(collate(a) = å):boolean> +-- !query output +true + + +-- !query +select 'a' collate unicode_ci_ai = 'Å' +-- !query schema +struct<(collate(a) = Å):boolean> +-- !query output +true + + +-- !query +select 'a' collate en < 'A' +-- !query schema +struct<(collate(a) < A):boolean> +-- !query output +true + + +-- !query +select 'a' collate en_ci = 'A' +-- !query schema +struct<(collate(a) = A):boolean> +-- !query output +true + + +-- !query +select 'a' collate en_ai = 'å' +-- !query schema +struct<(collate(a) = å):boolean> +-- !query output +true + + +-- !query +select 'a' collate en_ci_ai = 'Å' +-- !query schema +struct<(collate(a) = Å):boolean> +-- !query output +true + + +-- !query +select 'Kypper' collate sv < 'Köpfe' +-- !query schema +struct<(collate(Kypper) < Köpfe):boolean> +-- !query output +true + + +-- !query +select 'Kypper' collate de > 'Köpfe' +-- !query schema +struct<(collate(Kypper) > Köpfe):boolean> +-- !query output +true + + +-- !query +select 'I' collate tr_ci = 'ı' +-- !query schema +struct<(collate(I) = ı):boolean> +-- !query output +true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 657fd4504cac..4f8587395b3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -152,7 +152,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { exception = intercept[SparkException] { 
sql("select 'aaa' collate UTF8_BS") }, errorClass = "COLLATION_INVALID_NAME", sqlState = "42704", - parameters = Map("proposal" -> "UTF8_BINARY", "collationName" -> "UTF8_BS")) + parameters = Map("collationName" -> "UTF8_BS")) } test("disable bucketing on collated string column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ICUCollationsMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ICUCollationsMapSuite.scala new file mode 100644 index 000000000000..42d486bd7545 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ICUCollationsMapSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile, CollationFactory} + +// scalastyle:off line.size.limit +/** + * Guard against breaking changes in ICU locale names and codes supported by Collator class and provided by CollationFactory.
+ * Map is in form of rows of pairs (locale name, locale id); locale name consists of three parts: + * - 2-letter lowercase language code + * - 4-letter script code (optional) + * - 3-letter uppercase country code + * + * To re-generate collations map golden file, run: + * {{{ + * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly org.apache.spark.sql.ICUCollationsMapSuite" + * }}} + */ +// scalastyle:on line.size.limit +class ICUCollationsMapSuite extends SparkFunSuite { + + private val collationsMapFile = { + getWorkspaceFilePath("sql", "core", "src", "test", "resources", + "collations", "ICU-collations-map.md").toFile + } + + if (regenerateGoldenFiles) { + val map = CollationFactory.getICULocaleNames + val mapOutput = map.zipWithIndex.map { + case (localeName, idx) => s"| $idx | $localeName |" }.mkString("\n") + val goldenOutput = { + s"<!-- Automatically generated by ${getClass.getSimpleName} -->\n" + + "## ICU locale ids to name map\n" + + "| Locale id | Locale name |\n" + + "| --------- | ----------- |\n" + + mapOutput + "\n" + } + val parent = collationsMapFile.getParentFile + if (!parent.exists()) { + assert(parent.mkdirs(), "Could not create directory: " + parent) + } + stringToFile(collationsMapFile, goldenOutput) + } + + test("ICU locales map breaking change") { + val goldenLines = fileToString(collationsMapFile).split('\n') + val goldenRelevantLines = goldenLines.slice(4, goldenLines.length) // skip header + val input = goldenRelevantLines.map( + s => (s.split('|')(2).strip(), s.split('|')(1).strip().toInt)) + assert(input sameElements CollationFactory.getICULocaleNames.zipWithIndex) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 213dfd32c869..8d291591c5f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ 
-518,8 +518,7 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { errorClass = "INVALID_CONF_VALUE.DEFAULT_COLLATION", parameters = Map( "confValue" -> "UNICODE_C", - "confName" -> "spark.sql.session.collation.default", - "proposal" -> "UNICODE_CI" + "confName" -> "spark.sql.session.collation.default" )) withSQLConf(SQLConf.COLLATION_ENABLED.key -> "false") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org