This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6f6b4860268d [SPARK-48175][SQL][PYTHON] Store collation information in metadata and not in type for SER/DE
6f6b4860268d is described below

commit 6f6b4860268dc250d8e31a251d740733798aa512
Author: Stefan Kandic <stefan.kan...@databricks.com>
AuthorDate: Sat May 18 15:17:56 2024 +0800

    [SPARK-48175][SQL][PYTHON] Store collation information in metadata and not in type for SER/DE
    
    ### What changes were proposed in this pull request?
    This changes the serialization and deserialization of collated strings so that the collation information is put in the metadata of the enclosing struct field, and then read back from there during parsing.
    
    The serialized format will look something like this:
    ```json
    {
      "type": "struct",
      "fields": [
        {
          "name": "colName",
          "type": "string",
          "nullable": true,
          "metadata": {
            "__COLLATIONS": {
              "colName": "icu.UNICODE"
            }
          }
        }
      ]
    }
    ```
    
    If we have a map, we will add the suffixes `.key` and `.value` in the metadata:
    ```json
    {
      "type": "struct",
      "fields": [
        {
          "name": "mapField",
          "type": {
            "type": "map",
            "keyType": "string",
            "valueType": "string",
            "valueContainsNull": true
          },
          "nullable": true,
          "metadata": {
            "__COLLATIONS": {
              "mapField.key": "UNICODE",
              "mapField.value": "UNICODE"
            }
          }
        }
      ]
    }
    ```
    It will be a similar story for arrays (we will add an `.element` suffix), and we can end up with multiple suffixes when working with deeply nested data types (`Map[String, Array[Array[String]]]` - see the tests for this example and the sketch below).
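
    For illustration, the array case will serialize like this (this matches the `arrayField` example covered by the new `StructTypeSuite` tests):
    ```json
    {
      "type": "struct",
      "fields": [
        {
          "name": "arrayField",
          "type": {
            "type": "array",
            "elementType": "string",
            "containsNull": true
          },
          "nullable": true,
          "metadata": {
            "__COLLATIONS": {
              "arrayField.element": "icu.UNICODE"
            }
          }
        }
      ]
    }
    ```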
    
    ### Why are the changes needed?
    Putting the collation info in field metadata is the only way to avoid breaking old clients that read new tables with collations. `CharVarcharUtils` does a similar thing, but this approach is much less hacky and friendlier to third-party clients - which is especially important since Delta also uses Spark for schema ser/de.
    
    It will also remove the need for the additional logic introduced in #46083 to strip collations before writing to HMS, as this way the tables will be fully HMS-compatible.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    With unit tests.
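
    As a minimal sketch of the round trip these tests verify (mirroring the new `DataTypeSuite` test; the `UNICODE` collation is just an example):
    ```scala
    import org.apache.spark.sql.catalyst.util.CollationFactory
    import org.apache.spark.sql.types._

    val unicodeId = CollationFactory.collationNameToId("UNICODE")
    val schema = StructType(StructField("c1", StringType(unicodeId)) :: Nil)

    // The string type itself serializes as plain "string"; the collation is
    // written to the enclosing field's metadata under the "__COLLATIONS" key.
    val json = schema.json

    // Parsing reads the collation back from the metadata, so the schema
    // round-trips unchanged.
    assert(DataType.fromJson(json) === schema)
    ```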
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46280 from stefankandic/newDeltaSchema.
    
    Lead-authored-by: Stefan Kandic <stefan.kan...@databricks.com>
    Co-authored-by: Stefan Kandic <154237371+stefankan...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/util/CollationFactory.java  |  99 +++++++-
 .../src/main/resources/error/error-conditions.json |  12 +
 python/pyspark/errors/error-conditions.json        |  10 +
 .../pyspark/sql/tests/connect/test_parity_types.py |   4 +
 python/pyspark/sql/tests/test_types.py             | 249 +++++++++++++++++++--
 python/pyspark/sql/types.py                        | 178 +++++++++++++--
 .../org/apache/spark/sql/types/DataType.scala      |  74 +++++-
 .../org/apache/spark/sql/types/StringType.scala    |   7 +
 .../org/apache/spark/sql/types/StructField.scala   |  62 ++++-
 .../org/apache/spark/sql/types/DataTypeSuite.scala | 181 ++++++++++++++-
 .../apache/spark/sql/types/StructTypeSuite.scala   | 183 +++++++++++++++
 .../streaming/StreamingDeduplicationSuite.scala    |   2 +-
 .../spark/sql/streaming/StreamingQuerySuite.scala  |   2 +-
 13 files changed, 1004 insertions(+), 59 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 8ffff63445b6..0133c3feb611 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -36,11 +36,62 @@ import org.apache.spark.unsafe.types.UTF8String;
  * Provides functionality to the UTF8String object which respects defined collation settings.
  */
 public final class CollationFactory {
+
+  /**
+   * Identifier for a single collation.
+   */
+  public static class CollationIdentifier {
+    private final String provider;
+    private final String name;
+    private final String version;
+
+    public CollationIdentifier(String provider, String collationName, String version) {
+      this.provider = provider;
+      this.name = collationName;
+      this.version = version;
+    }
+
+    public static CollationIdentifier fromString(String identifier) {
+      long numDots = identifier.chars().filter(ch -> ch == '.').count();
+      assert(numDots > 0);
+
+      if (numDots == 1) {
+        String[] parts = identifier.split("\\.", 2);
+        return new CollationIdentifier(parts[0], parts[1], null);
+      }
+
+      String[] parts = identifier.split("\\.", 3);
+      return new CollationIdentifier(parts[0], parts[1], parts[2]);
+    }
+
+    /**
+     * Returns the identifier's string value without the version.
+     * This is used for the table schema as the schema doesn't care about the version,
+     * only the statistics do.
+     */
+    public String toStringWithoutVersion() {
+      return String.format("%s.%s", provider, name);
+    }
+
+    public String getProvider() {
+      return provider;
+    }
+
+    public String getName() {
+      return name;
+    }
+
+    public Optional<String> getVersion() {
+      return Optional.ofNullable(version);
+    }
+  }
+
   /**
    * Entry encapsulating all information about a collation.
    */
   public static class Collation {
     public final String collationName;
+    public final String provider;
     public final Collator collator;
     public final Comparator<UTF8String> comparator;
 
@@ -89,6 +140,7 @@ public final class CollationFactory {
 
     public Collation(
         String collationName,
+        String provider,
         Collator collator,
         Comparator<UTF8String> comparator,
         String version,
@@ -97,6 +149,7 @@ public final class CollationFactory {
         boolean supportsBinaryOrdering,
         boolean supportsLowercaseEquality) {
       this.collationName = collationName;
+      this.provider = provider;
       this.collator = collator;
       this.comparator = comparator;
       this.version = version;
@@ -110,6 +163,8 @@ public final class CollationFactory {
      // No Collation can simultaneously support binary equality and lowercase equality
       assert(!supportsBinaryEquality || !supportsLowercaseEquality);
 
+      assert(SUPPORTED_PROVIDERS.contains(provider));
+
       if (supportsBinaryEquality) {
         this.equalsFunction = UTF8String::equals;
       } else {
@@ -122,6 +177,7 @@ public final class CollationFactory {
      */
     public Collation(
         String collationName,
+        String provider,
         Collator collator,
         String version,
         boolean supportsBinaryEquality,
@@ -129,6 +185,7 @@ public final class CollationFactory {
         boolean supportsLowercaseEquality) {
       this(
         collationName,
+        provider,
         collator,
         (s1, s2) -> collator.compare(s1.toString(), s2.toString()),
         version,
@@ -137,6 +194,11 @@ public final class CollationFactory {
         supportsBinaryOrdering,
         supportsLowercaseEquality);
     }
+
+    /** Returns the collation identifier. */
+    public CollationIdentifier identifier() {
+      return new CollationIdentifier(provider, collationName, version);
+    }
   }
 
   private static final Collation[] collationTable = new Collation[4];
@@ -145,12 +207,17 @@ public final class CollationFactory {
   public static final int UTF8_BINARY_COLLATION_ID = 0;
   public static final int UTF8_BINARY_LCASE_COLLATION_ID = 1;
 
+  public static final String PROVIDER_SPARK = "spark";
+  public static final String PROVIDER_ICU = "icu";
+  public static final List<String> SUPPORTED_PROVIDERS = List.of(PROVIDER_SPARK, PROVIDER_ICU);
+
   static {
     // Binary comparison. This is the default collation.
     // No custom comparators will be used for this collation.
     // Instead, we rely on byte for byte comparison.
     collationTable[0] = new Collation(
       "UTF8_BINARY",
+      PROVIDER_SPARK,
       null,
       UTF8String::binaryCompare,
       "1.0",
@@ -163,6 +230,7 @@ public final class CollationFactory {
     // TODO: Do in place comparisons instead of creating new strings.
     collationTable[1] = new Collation(
       "UTF8_BINARY_LCASE",
+      PROVIDER_SPARK,
       null,
       UTF8String::compareLowerCase,
       "1.0",
@@ -173,13 +241,28 @@ public final class CollationFactory {
 
     // UNICODE case sensitive comparison (ROOT locale, in ICU).
     collationTable[2] = new Collation(
-      "UNICODE", Collator.getInstance(ULocale.ROOT), "153.120.0.0", true, 
false, false);
+      "UNICODE",
+      PROVIDER_ICU,
+      Collator.getInstance(ULocale.ROOT),
+      "153.120.0.0",
+      true,
+      false,
+      false
+    );
+
     collationTable[2].collator.setStrength(Collator.TERTIARY);
     collationTable[2].collator.freeze();
 
     // UNICODE case-insensitive comparison (ROOT locale, in ICU + Secondary strength).
     collationTable[3] = new Collation(
-      "UNICODE_CI", Collator.getInstance(ULocale.ROOT), "153.120.0.0", false, 
false, false);
+      "UNICODE_CI",
+      PROVIDER_ICU,
+      Collator.getInstance(ULocale.ROOT),
+      "153.120.0.0",
+      false,
+      false,
+      false
+    );
     collationTable[3].collator.setStrength(Collator.SECONDARY);
     collationTable[3].collator.freeze();
 
@@ -263,6 +346,18 @@ public final class CollationFactory {
     }
   }
 
+  public static void assertValidProvider(String provider) throws SparkException {
+    if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) {
+      Map<String, String> params = Map.of(
+        "provider", provider,
+        "supportedProviders", String.join(", ", SUPPORTED_PROVIDERS)
+      );
+
+      throw new SparkException(
+        "COLLATION_INVALID_PROVIDER", 
SparkException.constructMessageParams(params), null);
+    }
+  }
+
   public static Collation fetchCollation(int collationId) {
     return collationTable[collationId];
   }
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 69889435b02e..c1c0cd6bfb39 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -473,6 +473,12 @@
     ],
     "sqlState" : "42704"
   },
+  "COLLATION_INVALID_PROVIDER" : {
+    "message" : [
+      "The value <provider> does not represent a correct collation provider. 
Supported providers are: [<supportedProviders>]."
+    ],
+    "sqlState" : "42704"
+  },
   "COLLATION_MISMATCH" : {
     "message" : [
       "Could not determine which collation to use for string functions and 
operators."
@@ -2342,6 +2348,12 @@
     ],
     "sqlState" : "2203G"
   },
+  "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS" : {
+    "message" : [
+      "Collations can only be applied to string types, but the JSON data type 
is <jsonType>."
+    ],
+    "sqlState" : "2203G"
+  },
   "INVALID_JSON_ROOT_FIELD" : {
     "message" : [
       "Cannot convert JSON root field to target Spark type."
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 906bf781e1bb..30db37387249 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -149,6 +149,11 @@
       "Cannot <condition1> without <condition2>."
     ]
   },
+  "COLLATION_INVALID_PROVIDER" : {
+    "message" : [
+      "The value <provider> does not represent a correct collation provider. 
Supported providers are: [<supportedProviders>]."
+    ]
+  },
   "COLUMN_IN_LIST": {
     "message": [
       "`<func_name>` does not allow a Column in a list."
@@ -357,6 +362,11 @@
       "All items in `<arg_name>` should be in <allowed_types>, got 
<item_type>."
     ]
   },
+  "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS" : {
+    "message" : [
+      "Collations can only be applied to string types, but the JSON data type 
is <jsonType>."
+    ]
+  },
   "INVALID_MULTIPLE_ARGUMENT_CONDITIONS": {
     "message": [
       "[{arg_names}] cannot be <condition>."
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 9e81af47ceb0..4f736ac1215c 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -90,6 +90,10 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
     def test_udt(self):
         super().test_udt()
 
+    @unittest.skip("Requires JVM access.")
+    def test_schema_with_collations_json_ser_de(self):
+        super().test_schema_with_collations_json_ser_de()
+
     @unittest.skip("Does not test anything related to Spark Connect")
     def test_parse_datatype_string(self):
         super().test_parse_datatype_string()
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 5942ae2abdb3..4d6fc499b70b 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -600,6 +600,234 @@ class TypesTestsMixin:
         self.assertEqual(df.count(), 1)
         self.assertEqual(df.head(), Row(name="[123]", income=120))
 
+    def test_schema_with_collations_json_ser_de(self):
+        from pyspark.sql.types import _parse_datatype_json_string
+
+        unicode_collation = "UNICODE"
+
+        simple_struct = StructType([StructField("c1", StringType(unicode_collation))])
+
+        nested_struct = StructType([StructField("nested", simple_struct)])
+
+        array_in_schema = StructType(
+            [StructField("array", ArrayType(StringType(unicode_collation)))]
+        )
+
+        map_in_schema = StructType(
+            [
+                StructField(
+                    "map", MapType(StringType(unicode_collation), 
StringType(unicode_collation))
+                )
+            ]
+        )
+
+        nested_map = StructType(
+            [
+                StructField(
+                    "nested",
+                    StructType(
+                        [
+                            StructField(
+                                "mapField",
+                                MapType(
+                                    StringType(unicode_collation), StringType(unicode_collation)
+                                ),
+                            )
+                        ]
+                    ),
+                )
+            ]
+        )
+
+        array_in_map = StructType(
+            [
+                StructField(
+                    "arrInMap",
+                    MapType(
+                        StringType(unicode_collation), ArrayType(StringType(unicode_collation))
+                    ),
+                )
+            ]
+        )
+
+        nested_array_in_map_value = StructType(
+            [
+                StructField(
+                    "nestedArrayInMap",
+                    ArrayType(
+                        MapType(
+                            StringType(unicode_collation),
+                            ArrayType(ArrayType(StringType(unicode_collation))),
+                        )
+                    ),
+                )
+            ]
+        )
+
+        schema_with_multiple_fields = StructType(
+            simple_struct.fields
+            + nested_struct.fields
+            + array_in_schema.fields
+            + map_in_schema.fields
+            + nested_map.fields
+            + array_in_map.fields
+            + nested_array_in_map_value.fields
+        )
+
+        schemas = [
+            simple_struct,
+            nested_struct,
+            array_in_schema,
+            map_in_schema,
+            nested_map,
+            nested_array_in_map_value,
+            array_in_map,
+            schema_with_multiple_fields,
+        ]
+
+        for schema in schemas:
+            scala_datatype = self.spark._jsparkSession.parseDataType(schema.json())
+            python_datatype = _parse_datatype_json_string(scala_datatype.json())
+            assert schema == python_datatype
+            assert schema == _parse_datatype_json_string(schema.json())
+
+    def test_schema_with_collations_on_non_string_types(self):
+        from pyspark.sql.types import _parse_datatype_json_string, _COLLATIONS_METADATA_KEY
+
+        collations_on_int_col_json = f"""
+        {{
+          "type": "struct",
+          "fields": [
+            {{
+              "name": "c1",
+              "type": "integer",
+              "nullable": true,
+              "metadata": {{
+                "{_COLLATIONS_METADATA_KEY}": {{
+                  "c1": "icu.UNICODE"
+                }}
+              }}
+            }}
+          ]
+        }}
+        """
+
+        collations_in_array_element_json = f"""
+        {{
+          "type": "struct",
+          "fields": [
+            {{
+              "name": "arrayField",
+              "type": {{
+                  "type": "array",
+                  "elementType": "integer",
+                  "containsNull": true
+              }},
+              "nullable": true,
+              "metadata": {{
+                "{_COLLATIONS_METADATA_KEY}": {{
+                  "arrayField.element": "icu.UNICODE"
+                }}
+              }}
+            }}
+          ]
+        }}
+        """
+
+        collations_on_array_json = f"""
+        {{
+          "type": "struct",
+          "fields": [
+            {{
+              "name": "arrayField",
+              "type": {{
+                  "type": "array",
+                  "elementType": "integer",
+                  "containsNull": true
+              }},
+              "nullable": true,
+              "metadata": {{
+                "{_COLLATIONS_METADATA_KEY}": {{
+                  "arrayField": "icu.UNICODE"
+                }}
+              }}
+            }}
+          ]
+        }}
+        """
+
+        collations_in_nested_map_json = f"""
+        {{
+          "type": "struct",
+          "fields": [
+            {{
+              "name": "nested",
+              "type": {{
+                "type": "struct",
+                "fields": [
+                  {{
+                    "name": "mapField",
+                    "type": {{
+                      "type": "map",
+                      "keyType": "string",
+                      "valueType": "integer",
+                      "valueContainsNull": true
+                    }},
+                    "nullable": true,
+                    "metadata": {{
+                      "{_COLLATIONS_METADATA_KEY}": {{
+                        "mapField.value": "icu.UNICODE"
+                      }}
+                    }}
+                  }}
+                ]
+              }},
+              "nullable": true,
+              "metadata": {{}}
+            }}
+          ]
+        }}
+        """
+
+        self.assertRaises(
+            PySparkTypeError, lambda: _parse_datatype_json_string(collations_on_int_col_json)
+        )
+
+        self.assertRaises(
+            PySparkTypeError, lambda: _parse_datatype_json_string(collations_in_array_element_json)
+        )
+
+        self.assertRaises(
+            PySparkTypeError, lambda: _parse_datatype_json_string(collations_on_array_json)
+        )
+
+        self.assertRaises(
+            PySparkTypeError, lambda: _parse_datatype_json_string(collations_in_nested_map_json)
+        )
+
+    def test_schema_with_bad_collations_provider(self):
+        from pyspark.sql.types import _parse_datatype_json_string, _COLLATIONS_METADATA_KEY
+
+        schema_json = f"""
+        {{
+          "type": "struct",
+          "fields": [
+            {{
+              "name": "c1",
+              "type": "string",
+              "nullable": "true",
+              "metadata": {{
+                "{_COLLATIONS_METADATA_KEY}": {{
+                  "c1": "badProvider.UNICODE"
+                }}
+              }}
+            }}
+          ]
+        }}
+        """
+
+        self.assertRaises(PySparkValueError, lambda: _parse_datatype_json_string(schema_json))
+
     def test_udt(self):
        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
 
@@ -915,27 +1143,6 @@ class TypesTestsMixin:
                 self.assertEqual(t(), _parse_datatype_string(k))
         self.assertEqual(IntegerType(), _parse_datatype_string("int"))
         self.assertEqual(StringType(), _parse_datatype_string("string"))
-        self.assertEqual(StringType(), _parse_datatype_string("string collate UTF8_BINARY"))
-        self.assertEqual(StringType(), _parse_datatype_string("string COLLATE UTF8_BINARY"))
-        self.assertEqual(
-            StringType.fromCollationId(0), _parse_datatype_string("string COLLATE   UTF8_BINARY")
-        )
-        self.assertEqual(
-            StringType.fromCollationId(1),
-            _parse_datatype_string("string COLLATE UTF8_BINARY_LCASE"),
-        )
-        self.assertEqual(
-            StringType.fromCollationId(2), _parse_datatype_string("string COLLATE UNICODE")
-        )
-        self.assertEqual(
-            StringType.fromCollationId(2), _parse_datatype_string("string COLLATE `UNICODE`")
-        )
-        self.assertEqual(
-            StringType.fromCollationId(3), _parse_datatype_string("string COLLATE UNICODE_CI")
-        )
-        self.assertEqual(
-            StringType.fromCollationId(3), _parse_datatype_string("string COLLATE `UNICODE_CI`")
-        )
         self.assertEqual(CharType(1), _parse_datatype_string("char(1)"))
         self.assertEqual(CharType(10), _parse_datatype_string("char( 10   )"))
         self.assertEqual(CharType(11), _parse_datatype_string("char( 11)"))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index ed3535e7d4aa..d692fd6f3681 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -255,6 +255,9 @@ class StringType(AtomicType):
     """
 
     collationNames = ["UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI"]
+    providerSpark = "spark"
+    providerICU = "icu"
+    providers = [providerSpark, providerICU]
 
     def __init__(self, collation: Optional[str] = None):
         self.collationId = 0 if collation is None else self.collationNameToId(collation)
@@ -263,21 +266,32 @@ class StringType(AtomicType):
     def fromCollationId(self, collationId: int) -> "StringType":
         return StringType(StringType.collationNames[collationId])
 
-    def collationIdToName(self) -> str:
-        if self.collationId == 0:
-            return ""
-        else:
-            return " collate %s" % StringType.collationNames[self.collationId]
+    @classmethod
+    def collationIdToName(cls, collationId: int) -> str:
+        return StringType.collationNames[collationId]
 
     @classmethod
     def collationNameToId(cls, collationName: str) -> int:
         return StringType.collationNames.index(collationName)
 
+    @classmethod
+    def collationProvider(cls, collationName: str) -> str:
+        # TODO: do this properly like on the scala side
+        if collationName.startswith("UTF8"):
+            return StringType.providerSpark
+        return StringType.providerICU
+
     def simpleString(self) -> str:
-        return "string" + self.collationIdToName()
+        if self.isUTF8BinaryCollation():
+            return "string"
 
+        return f"string collate ${self.collationIdToName(self.collationId)}"
+
+    # For backwards compatibility and compatibility with other readers all string types
+    # are serialized in json as regular strings and the collation info is written to
+    # struct field metadata
     def jsonValue(self) -> str:
-        return "string" + self.collationIdToName()
+        return "string"
 
     def __repr__(self) -> str:
         return (
@@ -286,6 +300,9 @@ class StringType(AtomicType):
             else "StringType()"
         )
 
+    def isUTF8BinaryCollation(self) -> bool:
+        return self.collationId == 0
+
 
 class CharType(AtomicType):
     """Char data type
@@ -693,8 +710,16 @@ class ArrayType(DataType):
         }
 
     @classmethod
-    def fromJson(cls, json: Dict[str, Any]) -> "ArrayType":
-        return ArrayType(_parse_datatype_json_value(json["elementType"]), json["containsNull"])
+    def fromJson(
+        cls,
+        json: Dict[str, Any],
+        fieldPath: str,
+        collationsMap: Optional[Dict[str, str]],
+    ) -> "ArrayType":
+        elementType = _parse_datatype_json_value(
+            json["elementType"], fieldPath + ".element", collationsMap
+        )
+        return ArrayType(elementType, json["containsNull"])
 
     def needConversion(self) -> bool:
         return self.elementType.needConversion()
@@ -810,10 +835,19 @@ class MapType(DataType):
         }
 
     @classmethod
-    def fromJson(cls, json: Dict[str, Any]) -> "MapType":
+    def fromJson(
+        cls,
+        json: Dict[str, Any],
+        fieldPath: str,
+        collationsMap: Optional[Dict[str, str]],
+    ) -> "MapType":
+        keyType = _parse_datatype_json_value(json["keyType"], fieldPath + ".key", collationsMap)
+        valueType = _parse_datatype_json_value(
+            json["valueType"], fieldPath + ".value", collationsMap
+        )
         return MapType(
-            _parse_datatype_json_value(json["keyType"]),
-            _parse_datatype_json_value(json["valueType"]),
+            keyType,
+            valueType,
             json["valueContainsNull"],
         )
 
@@ -884,22 +918,89 @@ class StructField(DataType):
         return "StructField('%s', %s, %s)" % (self.name, self.dataType, 
str(self.nullable))
 
     def jsonValue(self) -> Dict[str, Any]:
+        collationMetadata = self.getCollationMetadata()
+        metadata = (
+            self.metadata
+            if not collationMetadata
+            else {**self.metadata, _COLLATIONS_METADATA_KEY: collationMetadata}
+        )
+
         return {
             "name": self.name,
             "type": self.dataType.jsonValue(),
             "nullable": self.nullable,
-            "metadata": self.metadata,
+            "metadata": metadata,
         }
 
     @classmethod
     def fromJson(cls, json: Dict[str, Any]) -> "StructField":
+        metadata = json.get("metadata")
+        collationsMap = {}
+        if metadata and _COLLATIONS_METADATA_KEY in metadata:
+            collationsMap = metadata[_COLLATIONS_METADATA_KEY]
+            for key, value in collationsMap.items():
+                nameParts = value.split(".")
+                assert len(nameParts) == 2
+                provider, name = nameParts[0], nameParts[1]
+                _assert_valid_collation_provider(provider)
+                collationsMap[key] = name
+
+            metadata = {
+                key: value for key, value in metadata.items() if key != _COLLATIONS_METADATA_KEY
+            }
+
         return StructField(
             json["name"],
-            _parse_datatype_json_value(json["type"]),
+            _parse_datatype_json_value(json["type"], json["name"], 
collationsMap),
             json.get("nullable", True),
-            json.get("metadata"),
+            metadata,
         )
 
+    def getCollationsMap(self, metadata: Dict[str, Any]) -> Dict[str, str]:
+        if not metadata or _COLLATIONS_METADATA_KEY not in metadata:
+            return {}
+
+        collationMetadata: Dict[str, str] = metadata[_COLLATIONS_METADATA_KEY]
+        collationsMap: Dict[str, str] = {}
+
+        for key, value in collationMetadata.items():
+            nameParts = value.split(".")
+            assert len(nameParts) == 2
+            provider, name = nameParts[0], nameParts[1]
+            _assert_valid_collation_provider(provider)
+            collationsMap[key] = name
+
+        return collationsMap
+
+    def getCollationMetadata(self) -> Dict[str, str]:
+        def visitRecursively(dt: DataType, fieldPath: str) -> None:
+            if isinstance(dt, ArrayType):
+                processDataType(dt.elementType, fieldPath + ".element")
+            elif isinstance(dt, MapType):
+                processDataType(dt.keyType, fieldPath + ".key")
+                processDataType(dt.valueType, fieldPath + ".value")
+            elif isinstance(dt, StringType) and self._isCollatedString(dt):
+                collationMetadata[fieldPath] = self.schemaCollationValue(dt)
+
+        def processDataType(dt: DataType, fieldPath: str) -> None:
+            if self._isCollatedString(dt):
+                collationMetadata[fieldPath] = self.schemaCollationValue(dt)
+            else:
+                visitRecursively(dt, fieldPath)
+
+        collationMetadata: Dict[str, str] = {}
+        visitRecursively(self.dataType, self.name)
+        return collationMetadata
+
+    def _isCollatedString(self, dt: DataType) -> bool:
+        return isinstance(dt, StringType) and not dt.isUTF8BinaryCollation()
+
+    def schemaCollationValue(self, dt: DataType) -> str:
+        assert isinstance(dt, StringType)
+        collationName = StringType.collationIdToName(dt.collationId)
+        provider = StringType.collationProvider(collationName)
+        return f"{provider}.{collationName}"
+
     def needConversion(self) -> bool:
         return self.dataType.needConversion()
 
@@ -1561,13 +1662,14 @@ _all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dic
     (v.typeName(), v) for v in _complex_types
 )
 
-_COLLATED_STRING = re.compile(r"string\s+collate\s+([\w_]+|`[\w_]`)")
 _LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)")
 _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
 _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
 
+_COLLATIONS_METADATA_KEY = "__COLLATIONS"
+
 
 def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]:
     assert isinstance(d, (DataType, StructField))
@@ -1715,9 +1817,17 @@ def _parse_datatype_json_string(json_string: str) -> DataType:
     return _parse_datatype_json_value(json.loads(json_string))
 
 
-def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
+def _parse_datatype_json_value(
+    json_value: Union[dict, str],
+    fieldPath: str = "",
+    collationsMap: Optional[Dict[str, str]] = None,
+) -> DataType:
     if not isinstance(json_value, dict):
         if json_value in _all_atomic_types.keys():
+            if collationsMap is not None and fieldPath in collationsMap:
+                _assert_valid_type_for_collation(fieldPath, json_value, collationsMap)
+                collation_name = collationsMap[fieldPath]
+                return StringType(collation_name)
             return _all_atomic_types[json_value]()
         elif json_value == "decimal":
             return DecimalType()
@@ -1742,9 +1852,6 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
             return YearMonthIntervalType(first_field, second_field)
         elif json_value == "interval":
             return CalendarIntervalType()
-        elif _COLLATED_STRING.match(json_value):
-            m = _COLLATED_STRING.match(json_value)
-            return StringType(m.group(1))  # type: ignore[union-attr]
         elif _LENGTH_CHAR.match(json_value):
             m = _LENGTH_CHAR.match(json_value)
             return CharType(int(m.group(1)))  # type: ignore[union-attr]
@@ -1759,7 +1866,15 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
     else:
         tpe = json_value["type"]
         if tpe in _all_complex_types:
-            return _all_complex_types[tpe].fromJson(json_value)
+            if collationsMap is not None and fieldPath in collationsMap:
+                _assert_valid_type_for_collation(fieldPath, tpe, collationsMap)
+
+            complex_type = _all_complex_types[tpe]
+            if complex_type is ArrayType:
+                return ArrayType.fromJson(json_value, fieldPath, collationsMap)
+            elif complex_type is MapType:
+                return MapType.fromJson(json_value, fieldPath, collationsMap)
+            return StructType.fromJson(json_value)
         elif tpe == "udt":
             return UserDefinedType.fromJson(json_value)
         else:
@@ -1769,6 +1884,27 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
             )
 
 
+def _assert_valid_type_for_collation(
+    fieldPath: str, fieldType: Any, collationMap: Dict[str, str]
+) -> None:
+    if fieldPath in collationMap and fieldType != "string":
+        raise PySparkTypeError(
+            error_class="INVALID_JSON_DATA_TYPE_FOR_COLLATIONS",
+            message_parameters={"jsonType": fieldType},
+        )
+
+
+def _assert_valid_collation_provider(provider: str) -> None:
+    if provider.lower() not in StringType.providers:
+        raise PySparkValueError(
+            error_class="COLLATION_INVALID_PROVIDER",
+            message_parameters={
+                "provider": provider,
+                "supportedProviders": ", ".join(StringType.providers),
+            },
+        )
+
+
 # Mapping Python types to Spark SQL DataType
 _type_mappings = {
     type(None): NullType,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 16cf6224ce27..0d53f5ae7902 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -117,7 +117,8 @@ object DataType {
   private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r
   private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r
   private val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r
-  private val COLLATED_STRING_TYPE = """string\s+collate\s+([\w_]+|`[\w_]`)""".r
+
+  val COLLATIONS_METADATA_KEY = "__COLLATIONS"
 
   def fromDDL(ddl: String): DataType = {
     parseTypeWithFallback(
@@ -182,9 +183,6 @@ object DataType {
   /** Given the string representation of a type, return its DataType */
   private def nameToType(name: String): DataType = {
     name match {
-      case COLLATED_STRING_TYPE(collation) =>
-        val collationId = CollationFactory.collationNameToId(collation)
-        StringType(collationId)
       case "decimal" => DecimalType.USER_DEFAULT
       case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
       case CHAR_TYPE(length) => CharType(length.toInt)
@@ -208,26 +206,40 @@ object DataType {
   }
 
   // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
-  private[sql] def parseDataType(json: JValue): DataType = json match {
+  private[sql] def parseDataType(
+      json: JValue,
+      fieldPath: String = "",
+      collationsMap: Map[String, String] = Map.empty): DataType = json match {
     case JString(name) =>
-      nameToType(name)
+      collationsMap.get(fieldPath) match {
+        case Some(collation) =>
+          assertValidTypeForCollations(fieldPath, name, collationsMap)
+          stringTypeWithCollation(collation)
+        case _ => nameToType(name)
+      }
 
     case JSortedObject(
     ("containsNull", JBool(n)),
     ("elementType", t: JValue),
     ("type", JString("array"))) =>
-      ArrayType(parseDataType(t), n)
+      assertValidTypeForCollations(fieldPath, "array", collationsMap)
+      val elementType = parseDataType(t, fieldPath + ".element", collationsMap)
+      ArrayType(elementType, n)
 
     case JSortedObject(
     ("keyType", k: JValue),
     ("type", JString("map")),
     ("valueContainsNull", JBool(n)),
     ("valueType", v: JValue)) =>
-      MapType(parseDataType(k), parseDataType(v), n)
+      assertValidTypeForCollations(fieldPath, "map", collationsMap)
+      val keyType = parseDataType(k, fieldPath + ".key", collationsMap)
+      val valueType = parseDataType(v, fieldPath + ".value", collationsMap)
+      MapType(keyType, valueType, n)
 
     case JSortedObject(
     ("fields", JArray(fields)),
     ("type", JString("struct"))) =>
+      assertValidTypeForCollations(fieldPath, "struct", collationsMap)
       StructType(fields.map(parseStructField))
 
     // Scala/Java UDT
@@ -253,11 +265,18 @@ object DataType {
 
   private def parseStructField(json: JValue): StructField = json match {
     case JSortedObject(
-    ("metadata", metadata: JObject),
+    ("metadata", JObject(metadataFields)),
     ("name", JString(name)),
     ("nullable", JBool(nullable)),
     ("type", dataType: JValue)) =>
-      StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata))
+      val collationsMap = getCollationsMap(metadataFields)
+      val metadataWithoutCollations =
+        JObject(metadataFields.filterNot(_._1 == COLLATIONS_METADATA_KEY))
+      StructField(
+        name,
+        parseDataType(dataType, name, collationsMap),
+        nullable,
+        Metadata.fromJObject(metadataWithoutCollations))
     // Support reading schema when 'metadata' is missing.
     case JSortedObject(
     ("name", JString(name)),
@@ -274,6 +293,41 @@ object DataType {
       messageParameters = Map("other" -> compact(render(other))))
   }
 
+  private def assertValidTypeForCollations(
+      fieldPath: String,
+      fieldType: String,
+      collationMap: Map[String, String]): Unit = {
+    if (collationMap.contains(fieldPath) && fieldType != "string") {
+      throw new SparkIllegalArgumentException(
+        errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS",
+        messageParameters = Map("jsonType" -> fieldType))
+    }
+  }
+
+  /**
+   * Returns a map of field path to collation name.
+   */
+  private def getCollationsMap(metadataFields: List[JField]): Map[String, String] = {
+    val collationsJsonOpt = metadataFields.find(_._1 == COLLATIONS_METADATA_KEY).map(_._2)
+    collationsJsonOpt match {
+      case Some(JObject(fields)) =>
+        fields.collect {
+          case (fieldPath, JString(collation)) =>
+            collation.split("\\.", 2) match {
+              case Array(provider: String, collationName: String) =>
+                CollationFactory.assertValidProvider(provider)
+                fieldPath -> collationName
+            }
+        }.toMap
+
+      case _ => Map.empty
+    }
+  }
+
+  private def stringTypeWithCollation(collationName: String): StringType = {
+    StringType(CollationFactory.collationNameToId(collationName))
+  }
+
   protected[types] def buildFormattedString(
       dataType: DataType,
       prefix: String,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
index 74c714ff63f4..b8dadbc9e1dc 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.types
 
+import org.json4s.JsonAST.{JString, JValue}
+
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.util.CollationFactory
 
@@ -61,6 +63,11 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
     if (isUTF8BinaryCollation) "string"
     else s"string collate 
${CollationFactory.fetchCollation(collationId).collationName}"
 
+  // Due to backwards compatibility and compatibility with other readers
+  // all string types are serialized in json as regular strings and
+  // the collation information is written to struct field metadata
+  override def jsonValue: JValue = JString("string")
+
   override def equals(obj: Any): Boolean =
     obj.isInstanceOf[StringType] && obj.asInstanceOf[StringType].collationId == collationId
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
index 66f9557db213..3ff96fea9ee0 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
@@ -17,11 +17,15 @@
 
 package org.apache.spark.sql.types
 
+import scala.collection.mutable
+
+import org.json4s.{JObject, JString}
 import org.json4s.JsonAST.JValue
 import org.json4s.JsonDSL._
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.catalyst.util.{QuotingUtils, StringConcat}
+import org.apache.spark.sql.catalyst.util.{CollationFactory, QuotingUtils, StringConcat}
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils.{CURRENT_DEFAULT_COLUMN_METADATA_KEY, EXISTS_DEFAULT_COLUMN_METADATA_KEY}
 import org.apache.spark.util.SparkSchemaUtils
 
@@ -63,7 +67,61 @@ case class StructField(
     ("name" -> name) ~
       ("type" -> dataType.jsonValue) ~
       ("nullable" -> nullable) ~
-      ("metadata" -> metadata.jsonValue)
+      ("metadata" -> metadataJson)
+  }
+
+  private def metadataJson: JValue = {
+    val metadataJsonValue = metadata.jsonValue
+    metadataJsonValue match {
+      case JObject(fields) if collationMetadata.nonEmpty =>
+        val collationFields = collationMetadata.map(kv => kv._1 -> JString(kv._2)).toList
+        JObject(fields :+ (DataType.COLLATIONS_METADATA_KEY -> JObject(collationFields)))
+
+      case _ => metadataJsonValue
+    }
+  }
+
+  /** Map of field path to collation name. */
+  private lazy val collationMetadata: Map[String, String] = {
+    val fieldToCollationMap = mutable.Map[String, String]()
+
+    def visitRecursively(dt: DataType, path: String): Unit = dt match {
+      case at: ArrayType =>
+        processDataType(at.elementType, path + ".element")
+
+      case mt: MapType =>
+        processDataType(mt.keyType, path + ".key")
+        processDataType(mt.valueType, path + ".value")
+
+      case st: StringType if isCollatedString(st) =>
+        fieldToCollationMap(path) = schemaCollationValue(st)
+
+      case _ =>
+    }
+
+    def processDataType(dt: DataType, path: String): Unit = {
+      if (isCollatedString(dt)) {
+        fieldToCollationMap(path) = schemaCollationValue(dt)
+      } else {
+        visitRecursively(dt, path)
+      }
+    }
+
+    visitRecursively(dataType, name)
+    fieldToCollationMap.toMap
+  }
+
+  private def isCollatedString(dt: DataType): Boolean = dt match {
+    case st: StringType => !st.isUTF8BinaryCollation
+    case _ => false
+  }
+
+  private def schemaCollationValue(dt: DataType): String = dt match {
+    case st: StringType =>
+      val collation = CollationFactory.fetchCollation(st.collationId)
+      collation.identifier().toStringWithoutVersion()
+    case _ =>
+      throw SparkException.internalError(s"Unexpected data type $dt")
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 3293957282e2..721d7c25d17b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -23,11 +23,13 @@ import org.apache.spark.{SparkException, SparkFunSuite, SparkIllegalArgumentExce
 import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.catalyst.util.StringConcat
+import org.apache.spark.sql.catalyst.util.{CollationFactory, StringConcat}
 import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes}
 
 class DataTypeSuite extends SparkFunSuite {
 
+  private val UNICODE_COLLATION_ID = CollationFactory.collationNameToId("UNICODE")
+
   test("construct an ArrayType") {
     val array = ArrayType(StringType)
 
@@ -712,4 +714,181 @@ class DataTypeSuite extends SparkFunSuite {
 
     assert(result === expected)
   }
+
+  test("schema with collation should not change during ser/de") {
+    val simpleStruct = StructType(
+      StructField("c1", StringType(UNICODE_COLLATION_ID)) :: Nil)
+
+    val nestedStruct = StructType(
+      StructField("nested", simpleStruct) :: Nil)
+
+    val caseInsensitiveNames = StructType(
+      StructField("c1", StringType(UNICODE_COLLATION_ID)) ::
+      StructField("C1", StringType(UNICODE_COLLATION_ID)) :: Nil)
+
+    val specialCharsInName = StructType(
+      StructField("c1.*23?", StringType(UNICODE_COLLATION_ID)) :: Nil)
+
+    val arrayInSchema = StructType(
+      StructField("arrayField", ArrayType(StringType(UNICODE_COLLATION_ID))) 
:: Nil)
+
+    val mapInSchema = StructType(
+      StructField("mapField",
+        MapType(StringType(UNICODE_COLLATION_ID), StringType(UNICODE_COLLATION_ID))) :: Nil)
+
+    val mapWithKeyInNameInSchema = StructType(
+      StructField("name.key", StringType) ::
+      StructField("name",
+        MapType(StringType(UNICODE_COLLATION_ID), StringType(UNICODE_COLLATION_ID))) :: Nil)
+
+    val arrayInMapInNestedSchema = StructType(
+      StructField("arrInMap",
+        MapType(StringType(UNICODE_COLLATION_ID),
+        ArrayType(StringType(UNICODE_COLLATION_ID)))) :: Nil)
+
+    val nestedArrayInMap = StructType(
+      StructField("nestedArrayInMap",
+        ArrayType(MapType(StringType(UNICODE_COLLATION_ID),
+          ArrayType(ArrayType(StringType(UNICODE_COLLATION_ID)))))) :: Nil)
+
+    val schemaWithMultipleFields = StructType(
+      simpleStruct.fields ++ nestedStruct.fields ++ arrayInSchema.fields ++ mapInSchema.fields ++
+        mapWithKeyInNameInSchema ++ arrayInMapInNestedSchema.fields ++ nestedArrayInMap.fields)
+
+    Seq(
+      simpleStruct, caseInsensitiveNames, specialCharsInName, nestedStruct, arrayInSchema,
+      mapInSchema, mapWithKeyInNameInSchema, nestedArrayInMap, arrayInMapInNestedSchema,
+      schemaWithMultipleFields)
+      .foreach { schema =>
+        val json = schema.json
+        val parsed = DataType.fromJson(json)
+        assert(parsed === schema)
+      }
+  }
+
+  test("non string field has collation metadata") {
+    val json =
+      s"""
+         |{
+         |  "type": "struct",
+         |  "fields": [
+         |    {
+         |      "name": "c1",
+         |      "type": "integer",
+         |      "nullable": true,
+         |      "metadata": {
+         |        "${DataType.COLLATIONS_METADATA_KEY}": {
+         |          "c1": "icu.UNICODE"
+         |        }
+         |      }
+         |    }
+         |  ]
+         |}
+         |""".stripMargin
+
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        DataType.fromJson(json)
+      },
+      errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS",
+      parameters = Map("jsonType" -> "integer")
+    )
+  }
+
+  test("non string field in map key has collation metadata") {
+    val json =
+      s"""
+         |{
+         |  "type": "struct",
+         |  "fields": [
+         |    {
+         |      "name": "mapField",
+         |      "type": {
+         |        "type": "map",
+         |        "keyType": "string",
+         |        "valueType": "integer",
+         |        "valueContainsNull": true
+         |      },
+         |      "nullable": true,
+         |      "metadata": {
+         |        "${DataType.COLLATIONS_METADATA_KEY}": {
+         |          "mapField.value": "icu.UNICODE"
+         |        }
+         |      }
+         |    }
+         |  ]
+         |}
+         |""".stripMargin
+
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        DataType.fromJson(json)
+      },
+      errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS",
+      parameters = Map("jsonType" -> "integer")
+    )
+  }
+
+  test("map field has collation metadata") {
+    val json =
+      s"""
+         |{
+         |  "type": "struct",
+         |  "fields": [
+         |    {
+         |      "name": "mapField",
+         |      "type": {
+         |        "type": "map",
+         |        "keyType": "string",
+         |        "valueType": "integer",
+         |        "valueContainsNull": true
+         |      },
+         |      "nullable": true,
+         |      "metadata": {
+         |        "${DataType.COLLATIONS_METADATA_KEY}": {
+         |          "mapField": "icu.UNICODE"
+         |        }
+         |      }
+         |    }
+         |  ]
+         |}
+         |""".stripMargin
+
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        DataType.fromJson(json)
+      },
+      errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS",
+      parameters = Map("jsonType" -> "map")
+    )
+  }
+
+  test("non existing collation provider") {
+    val json =
+      s"""
+         |{
+         |  "type": "struct",
+         |  "fields": [
+         |    {
+         |      "name": "c1",
+         |      "type": "string",
+         |      "nullable": true,
+         |      "metadata": {
+         |        "${DataType.COLLATIONS_METADATA_KEY}": {
+         |          "c1": "badProvider.UNICODE"
+         |        }
+         |      }
+         |    }
+         |  ]
+         |}
+         |""".stripMargin
+
+    checkError(
+      exception = intercept[SparkException] {
+        DataType.fromJson(json)
+      },
+      errorClass = "COLLATION_INVALID_PROVIDER",
+      parameters = Map("provider" -> "badProvider", "supportedProviders" -> 
"spark, icu")
+    )
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
index c165ab1bf61b..bd0685e10832 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.types
 
+import com.fasterxml.jackson.databind.ObjectMapper
+
 import org.apache.spark.{SparkException, SparkFunSuite, SparkIllegalArgumentException}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
@@ -36,6 +38,10 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
 
   private val s = StructType.fromDDL("a INT, b STRING")
 
+  private val UNICODE_COLLATION = "UNICODE"
+  private val UTF8_BINARY_LCASE_COLLATION = "UTF8_BINARY_LCASE"
+  private val mapper = new ObjectMapper()
+
   test("lookup a single missing field should output existing fields") {
     checkError(
       exception = intercept[SparkIllegalArgumentException](s("c")),
@@ -606,4 +612,181 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
         "b STRING NOT NULL,c STRING COMMENT 'nullable comment'")
     assert(fromDDL(struct.toDDL) === struct)
   }
+
+  test("simple struct with collations to json") {
+    val simpleStruct = StructType(
+      StructField("c1", StringType(UNICODE_COLLATION)) :: Nil)
+
+    val expectedJson =
+      s"""
+         |{
+         |  "type": "struct",
+         |  "fields": [
+         |    {
+         |      "name": "c1",
+         |      "type": "string",
+         |      "nullable": true,
+         |      "metadata": {
+         |        "${DataType.COLLATIONS_METADATA_KEY}": {
+         |          "c1": "icu.$UNICODE_COLLATION"
+         |        }
+         |      }
+         |    }
+         |  ]
+         |}
+         |""".stripMargin
+
+    assert(mapper.readTree(simpleStruct.json) == mapper.readTree(expectedJson))
+  }
+
+  test("nested struct with collations to json") {
+    val nestedStruct = StructType(
+      StructField("nested", StructType(
+        StructField("c1", StringType(UTF8_BINARY_LCASE_COLLATION)) :: Nil)) :: 
Nil)
+
+    val expectedJson =
+      s"""
+         |{
+         |  "type": "struct",
+         |  "fields": [
+         |    {
+         |      "name": "nested",
+         |      "type": {
+         |        "type": "struct",
+         |        "fields": [
+         |          {
+         |            "name": "c1",
+         |            "type": "string",
+         |            "nullable": true,
+         |            "metadata": {
+         |              "${DataType.COLLATIONS_METADATA_KEY}": {
+         |                "c1": "spark.$UTF8_BINARY_LCASE_COLLATION"
+         |              }
+         |            }
+         |          }
+         |        ]
+         |      },
+         |      "nullable": true,
+         |      "metadata": {}
+         |    }
+         |  ]
+         |}
+         |""".stripMargin
+
+    assert(mapper.readTree(nestedStruct.json) == mapper.readTree(expectedJson))
+  }
+
+  test("array with collations in schema to json") {
+    val arrayInSchema = StructType(
+      StructField("arrayField", ArrayType(StringType(UNICODE_COLLATION))) :: 
Nil)
+
+    val expectedJson =
+      s"""
+         |{
+         |  "type": "struct",
+         |  "fields": [
+         |    {
+         |      "name": "arrayField",
+         |      "type": {
+         |        "type": "array",
+         |        "elementType": "string",
+         |        "containsNull": true
+         |      },
+         |      "nullable": true,
+         |      "metadata": {
+         |        "${DataType.COLLATIONS_METADATA_KEY}": {
+         |          "arrayField.element": "icu.$UNICODE_COLLATION"
+         |        }
+         |      }
+         |    }
+         |  ]
+         |}
+         |""".stripMargin
+
+    assert(mapper.readTree(arrayInSchema.json) == mapper.readTree(expectedJson))
+  }
+
+  test("map with collations in schema to json") {
+    val mapInSchema = StructType(
+      StructField("mapField",
+        MapType(StringType(UNICODE_COLLATION), StringType(UNICODE_COLLATION))) :: Nil)
+
+    val expectedJson =
+      s"""
+         |{
+         |  "type": "struct",
+         |  "fields": [
+         |    {
+         |      "name": "mapField",
+         |      "type": {
+         |        "type": "map",
+         |        "keyType": "string",
+         |        "valueType": "string",
+         |        "valueContainsNull": true
+         |      },
+         |      "nullable": true,
+         |      "metadata": {
+         |        "${DataType.COLLATIONS_METADATA_KEY}": {
+         |          "mapField.key": "icu.$UNICODE_COLLATION",
+         |          "mapField.value": "icu.$UNICODE_COLLATION"
+         |        }
+         |      }
+         |    }
+         |  ]
+         |}
+         |""".stripMargin
+
+    assert(mapper.readTree(mapInSchema.json) == mapper.readTree(expectedJson))
+  }
+
+  test("nested array with collations in map to json" ) {
+    val mapWithNestedArray = StructType(
+      StructField("column", ArrayType(MapType(
+        StringType(UNICODE_COLLATION),
+        ArrayType(ArrayType(ArrayType(StringType(UNICODE_COLLATION))))))) :: Nil)
+
+    val expectedJson =
+      s"""
+         |{
+         |  "type": "struct",
+         |  "fields": [
+         |    {
+         |      "name": "column",
+         |      "type": {
+         |        "type": "array",
+         |        "elementType": {
+         |          "type": "map",
+         |          "keyType": "string",
+         |          "valueType": {
+         |            "type": "array",
+         |            "elementType": {
+         |              "type": "array",
+         |              "elementType": {
+         |                "type": "array",
+         |                "elementType": "string",
+         |                "containsNull": true
+         |              },
+         |              "containsNull": true
+         |            },
+         |            "containsNull": true
+         |          },
+         |          "valueContainsNull": true
+         |        },
+         |        "containsNull": true
+         |      },
+         |      "nullable": true,
+         |      "metadata": {
+         |        "${DataType.COLLATIONS_METADATA_KEY}": {
+         |          "column.element.key": "icu.$UNICODE_COLLATION",
+         |          "column.element.value.element.element.element": 
"icu.$UNICODE_COLLATION"
+         |        }
+         |      }
+         |    }
+         |  ]
+         |}
+         |""".stripMargin
+
+    assert(
+      mapper.readTree(mapWithNestedArray.json) == mapper.readTree(expectedJson))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
index 5c816c5cddc7..7c84e3e2d018 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
@@ -527,7 +527,7 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest {
       ex.getCause.asInstanceOf[SparkUnsupportedOperationException],
       errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY",
       parameters = Map(
-        "schema" -> ".+\"type\":\"string collate UTF8_BINARY_LCASE\".+"
+        "schema" -> ".+\"str\":\"spark.UTF8_BINARY_LCASE\".+"
       ),
       matchPVals = true
     )
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 227b50509afe..8b761c24b604 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -1425,7 +1425,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
         ex.getCause.asInstanceOf[SparkUnsupportedOperationException],
         errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY",
         parameters = Map(
-          "schema" -> ".+\"type\":\"string collate UTF8_BINARY_LCASE\".+"
+          "schema" -> ".+\"c1\":\"spark.UTF8_BINARY_LCASE\".+"
         ),
         matchPVals = true
       )


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
