nikolamand-db commented on code in PR #46180:
URL: https://github.com/apache/spark/pull/46180#discussion_r1601363774


##########
common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java:
##########
@@ -117,76 +119,445 @@ public Collation(
     }
 
     /**
-     * Constructor with comparators that are inherited from the given collator.
+     * collation id (32-bit integer) layout:
+     * bit 31:    0 = predefined collation, 1 = user-defined collation
+     * bit 30-29: 00 = utf8-binary, 01 = ICU, 10 = indeterminate (without spec 
implementation)
+     * bit 28:    0 for utf8-binary / 0 = case-sensitive, 1 = case-insensitive 
for ICU
+     * bit 27:    0 for utf8-binary / 0 = accent-sensitive, 1 = 
accent-insensitive for ICU
+     * bit 26-25: zeroes, reserved for punctuation sensitivity
+     * bit 24-23: zeroes, reserved for first letter preference
+     * bit 22-21: 00 = unspecified, 01 = to-lower, 10 = to-upper
+     * bit 20-19: zeroes, reserved for space trimming
+     * bit 18-17: zeroes, reserved for version
+     * bit 16-12: zeroes
+     * bit 11-0:  zeroes for utf8-binary / locale id for ICU
      */
-    public Collation(
-        String collationName,
-        Collator collator,
-        String version,
-        boolean supportsBinaryEquality,
-        boolean supportsBinaryOrdering,
-        boolean supportsLowercaseEquality) {
-      this(
-        collationName,
-        collator,
-        (s1, s2) -> collator.compare(s1.toString(), s2.toString()),
-        version,
-        s -> (long)collator.getCollationKey(s.toString()).hashCode(),
-        supportsBinaryEquality,
-        supportsBinaryOrdering,
-        supportsLowercaseEquality);
+    private abstract static class CollationSpec {
+      protected enum ImplementationProvider {
+        UTF8_BINARY, ICU, INDETERMINATE
+      }
+
+      protected enum CaseSensitivity {
+        CS, CI
+      }
+
+      protected enum AccentSensitivity {
+        AS, AI
+      }
+
+      protected enum CaseConversion {
+        UNSPECIFIED, LCASE, UCASE
+      }
+
+      protected static final int IMPLEMENTATION_PROVIDER_OFFSET = 29;
+      protected static final int IMPLEMENTATION_PROVIDER_MASK = 0b11;
+      protected static final int CASE_SENSITIVITY_OFFSET = 28;
+      protected static final int CASE_SENSITIVITY_MASK = 0b1;
+      protected static final int ACCENT_SENSITIVITY_OFFSET = 27;
+      protected static final int ACCENT_SENSITIVITY_MASK = 0b1;
+      protected static final int CASE_CONVERSION_OFFSET = 21;
+      protected static final int CASE_CONVERSION_MASK = 0b11;
+      protected static final int LOCALE_OFFSET = 0;
+      protected static final int LOCALE_MASK = 0x0FFF;
+
+      protected static final int INDETERMINATE_COLLATION_ID =
+        ImplementationProvider.INDETERMINATE.ordinal() << 
IMPLEMENTATION_PROVIDER_OFFSET;
+
+      protected final CaseSensitivity caseSensitivity;
+      protected final AccentSensitivity accentSensitivity;
+      protected final CaseConversion caseConversion;
+      protected final String locale;
+      protected final int collationId;
+
+      protected CollationSpec(
+          String locale,
+          CaseSensitivity caseSensitivity,
+          AccentSensitivity accentSensitivity,
+          CaseConversion caseConversion) {
+        this.locale = locale;
+        this.caseSensitivity = caseSensitivity;
+        this.accentSensitivity = accentSensitivity;
+        this.caseConversion = caseConversion;
+        this.collationId = getCollationId();
+      }
+
+      private static final Map<Integer, Collation> collationMap = new 
ConcurrentHashMap<>();
+
+      public static Collation fetchCollation(int collationId) throws 
SparkException {
+        if (collationId == UTF8_BINARY_COLLATION_ID) {
+          return CollationSpecUTF8Binary.UTF8_BINARY_COLLATION;
+        } else if (collationMap.containsKey(collationId)) {
+          return collationMap.get(collationId);
+        } else {
+          CollationSpec spec;
+          int implementationProviderOrdinal =
+            (collationId >> IMPLEMENTATION_PROVIDER_OFFSET) & 
IMPLEMENTATION_PROVIDER_MASK;
+          if (implementationProviderOrdinal >= 
ImplementationProvider.values().length) {
+            throw SparkException.internalError("Invalid collation 
implementation provider");
+          } else {
+            ImplementationProvider implementationProvider = 
ImplementationProvider.values()[
+              implementationProviderOrdinal];
+            if (implementationProvider == ImplementationProvider.UTF8_BINARY) {
+              spec = CollationSpecUTF8Binary.fromCollationId(collationId);
+            } else if (implementationProvider == ImplementationProvider.ICU) {
+              spec = CollationSpecICU.fromCollationId(collationId);
+            } else {
+              throw SparkException.internalError("Cannot instantiate 
indeterminate collation");
+            }
+            Collation collation = spec.buildCollation();
+            collationMap.put(collationId, collation);
+            return collation;
+          }
+        }
+      }
+
+      protected static SparkException collationInvalidNameException(String 
collationName) {
+        return new SparkException("COLLATION_INVALID_NAME",
+          SparkException.constructMessageParams(Map.of("collationName", 
collationName)), null);
+      }
+
+      public static int collationNameToId(String collationName) throws 
SparkException {
+        String collationNameUpper = collationName.toUpperCase();
+        if (collationNameUpper.startsWith("UTF8_BINARY")) {
+          return CollationSpecUTF8Binary.collationNameToId(collationName, 
collationNameUpper);
+        } else {
+          return CollationSpecICU.collationNameToId(collationName, 
collationNameUpper);
+        }
+      }
+
+      protected static int parseSpecifiers(
+          String originalName, String collationName, int splitStart) throws 
SparkException {
+        int specifiers = 0;
+        String[] parts = collationName.substring(splitStart).split("_");
+        for (String part : parts) {
+          if (!part.isEmpty()) {
+            if (part.equals("UNSPECIFIED") || part.equals("INDETERMINATE")) {
+              throw collationInvalidNameException(originalName);
+            } else if (Arrays.stream(CaseSensitivity.values()).anyMatch(
+                (s) -> s.toString().equals(part))) {
+              specifiers |=
+                CaseSensitivity.valueOf(part).ordinal() << 
CASE_SENSITIVITY_OFFSET;
+            } else if (Arrays.stream(AccentSensitivity.values()).anyMatch(
+                (s) -> part.equals(s.toString()))) {
+              specifiers |=
+                AccentSensitivity.valueOf(part).ordinal() << 
ACCENT_SENSITIVITY_OFFSET;
+            } else if (Arrays.stream(CaseConversion.values()).anyMatch(
+                (s) -> part.equals(s.toString()))) {
+              specifiers |=
+                CaseConversion.valueOf(part).ordinal() << 
CASE_CONVERSION_OFFSET;
+            } else {
+              throw collationInvalidNameException(originalName);
+            }
+          }
+        }
+        return specifiers;
+      }
+
+      protected abstract int getCollationId();
+      protected abstract Collation buildCollation();
     }
-  }
 
-  private static final Collation[] collationTable = new Collation[4];
-  private static final HashMap<String, Integer> collationNameToIdMap = new 
HashMap<>();
-
-  public static final int UTF8_BINARY_COLLATION_ID = 0;
-  public static final int UTF8_BINARY_LCASE_COLLATION_ID = 1;
-
-  static {
-    // Binary comparison. This is the default collation.
-    // No custom comparators will be used for this collation.
-    // Instead, we rely on byte for byte comparison.
-    collationTable[0] = new Collation(
-      "UTF8_BINARY",
-      null,
-      UTF8String::binaryCompare,
-      "1.0",
-      s -> (long)s.hashCode(),
-      true,
-      true,
-      false);
-
-    // Case-insensitive UTF8 binary collation.
-    // TODO: Do in place comparisons instead of creating new strings.
-    collationTable[1] = new Collation(
-      "UTF8_BINARY_LCASE",
-      null,
-      UTF8String::compareLowerCase,
-      "1.0",
-      (s) -> (long)s.toLowerCase().hashCode(),
-      false,
-      false,
-      true);
-
-    // UNICODE case sensitive comparison (ROOT locale, in ICU).
-    collationTable[2] = new Collation(
-      "UNICODE", Collator.getInstance(ULocale.ROOT), "153.120.0.0", true, 
false, false);
-    collationTable[2].collator.setStrength(Collator.TERTIARY);
-    collationTable[2].collator.freeze();
-
-    // UNICODE case-insensitive comparison (ROOT locale, in ICU + Secondary 
strength).
-    collationTable[3] = new Collation(
-      "UNICODE_CI", Collator.getInstance(ULocale.ROOT), "153.120.0.0", false, 
false, false);
-    collationTable[3].collator.setStrength(Collator.SECONDARY);
-    collationTable[3].collator.freeze();
-
-    for (int i = 0; i < collationTable.length; i++) {
-      collationNameToIdMap.put(collationTable[i].collationName, i);
+    private static class CollationSpecUTF8Binary extends CollationSpec {
+
+      public static final int UTF8_BINARY_COLLATION_ID =
+        new 
CollationSpecUTF8Binary(CaseConversion.UNSPECIFIED).getCollationId();
+      public static final int UTF8_BINARY_LCASE_COLLATION_ID =
+        new CollationSpecUTF8Binary(CaseConversion.LCASE).getCollationId();
+      public static Collation UTF8_BINARY_COLLATION =
+        new 
CollationSpecUTF8Binary(CaseConversion.UNSPECIFIED).buildCollation();
+
+      private CollationSpecUTF8Binary(CaseConversion caseConversion) {
+        super(null, CaseSensitivity.CS, AccentSensitivity.AS, caseConversion);
+      }
+
+      public static int collationNameToId(
+          String originalName, String collationName) throws SparkException {
+        int collationId = 0;
+        int specifiers = CollationSpec.parseSpecifiers(
+          originalName, collationName, "UTF8_BINARY".length());
+        if ((specifiers & ~(CASE_CONVERSION_MASK << CASE_CONVERSION_OFFSET)) 
!= 0) {
+          throw collationInvalidNameException(originalName);
+        }
+        collationId |= specifiers;
+        return collationId;
+      }
+
+      @Override
+      protected int getCollationId() {
+        int collationId = 0;
+        collationId |= caseConversion.ordinal() << CASE_CONVERSION_OFFSET;
+        return collationId;
+      }
+
+      public static CollationSpecUTF8Binary fromCollationId(int collationId)
+          throws SparkException {
+        int originalCollationId = collationId;
+        int caseConversionOrdinal =
+          (collationId >> CASE_CONVERSION_OFFSET) & CASE_CONVERSION_MASK;
+        collationId ^= caseConversionOrdinal << CASE_CONVERSION_OFFSET;
+        if (collationId != 0 || caseConversionOrdinal >= 
CaseConversion.values().length) {
+          throw SparkException.internalError("Invalid UTF8_BINARY collation id 
" +
+            originalCollationId);
+        } else {
+          CaseConversion caseConversion = 
CaseConversion.values()[caseConversionOrdinal];
+          return new CollationSpecUTF8Binary(caseConversion);
+        }
+      }
+
+      @Override
+      protected Collation buildCollation() {
+        Comparator<UTF8String> comparator;
+        if (collationId == UTF8_BINARY_COLLATION_ID) {
+          comparator = UTF8String::binaryCompare;
+        } else if (collationId == UTF8_BINARY_LCASE_COLLATION_ID) {
+          comparator = UTF8String::compareLowerCase;
+        } else {
+          comparator = (s1, s2) -> {
+            UTF8String convertedS1 = caseConversion(s1);
+            UTF8String convertedS2 = caseConversion(s2);
+            return convertedS1.binaryCompare(convertedS2);
+          };
+        }
+        return new Collation(
+          collationName(),
+          null,
+          comparator,
+          "1.0",
+          s -> (long) caseConversion(s).hashCode(),
+          collationId == UTF8_BINARY_COLLATION_ID,
+          collationId == UTF8_BINARY_COLLATION_ID,
+          collationId == UTF8_BINARY_LCASE_COLLATION_ID
+        );
+      }
+
+      private UTF8String caseConversion(UTF8String s) {
+        if (caseConversion == CaseConversion.LCASE) {
+          return s.toLowerCase();
+        } else if (caseConversion == CaseConversion.UCASE) {
+          return s.toUpperCase();
+        } else {
+          return s;
+        }
+      }
+
+      private String collationName() {
+        StringBuilder builder = new StringBuilder();
+        builder.append("UTF8_BINARY");
+        if (caseConversion != CaseConversion.UNSPECIFIED) {
+          builder.append('_');
+          builder.append(caseConversion.toString());
+        }
+        return builder.toString();
+      }
+    }
+
+    private static class CollationSpecICU extends CollationSpec {
+
+      private static final String[] ICULocaleNames;
+      private static final Map<String, ULocale> ICULocaleMap = new HashMap<>();
+      private static final Map<String, String> ICULocaleMapUppercase = new 
HashMap<>();
+      private static final Map<String, Integer> ICULocaleToId = new 
HashMap<>();
+      private static final String ICUCollatorVersion = "153.120.0.0";
+
+      static {
+        ICULocaleMap.put("UNICODE", ULocale.ROOT);
+        ULocale[] locales = Collator.getAvailableULocales();
+        for (ULocale locale : locales) {
+          if (locale.getVariant().isEmpty()) {
+            String language = locale.getLanguage();
+            assert (!language.isEmpty());
+            StringBuilder builder = new StringBuilder(language);
+            String script = locale.getScript();
+            if (!script.isEmpty()) {
+              builder.append('_');
+              builder.append(script);
+            }
+            String country = locale.getISO3Country();
+            if (!country.isEmpty()) {
+              builder.append('_');
+              builder.append(country);
+            }
+            String localeName = builder.toString();
+            assert (!ICULocaleMap.containsKey(localeName));
+            ICULocaleMap.put(localeName, locale);
+          }
+        }
+        for (String localeName : ICULocaleMap.keySet()) {
+          String localeUppercase = localeName.toUpperCase();
+          assert (!ICULocaleMapUppercase.containsKey(localeUppercase));
+          ICULocaleMapUppercase.put(localeUppercase, localeName);
+        }
+        ICULocaleNames = ICULocaleMap.keySet().toArray(new String[0]);
+        Arrays.sort(ICULocaleNames);
+        assert (ICULocaleNames.length <= (1 << 16));
+        for (int i = 0; i < ICULocaleNames.length; i++) {
+          ICULocaleToId.put(ICULocaleNames[i], i);
+        }

Review Comment:
   The goal is to have a stable mapping between integer ids and locale names. 
Please check the generated golden 
[file](https://github.com/apache/spark/pull/46180/files#diff-23fa9a48d97f6a09df9ffbaffce69c040bbd1be2a6cf430242adf9597839b3e3),
 which outlines this mapping. It's also described as part of the collation id 
[spec](https://github.com/apache/spark/pull/46180/files#diff-640c14aa5d7473df79b2435ce5a327dffcc16ca29354b153956b4f8d19fdb16cR133).
 Any change to the locale name-to-id mapping will require a golden file 
update, which gives the change more visibility.
   
   However, since collation ids, and therefore locale ids (which are part of 
the collation id), are now internal-only and never exposed to users (please 
confirm @dbatomic @stefankandic), they don't strictly have to remain the same 
across Spark versions. That said, it could be problematic to have executors 
running different Spark versions with different locale name-to-id mappings.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to