This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 804b2a416781 [SPARK-47506][SQL] Add support to all file source formats for collated data types
804b2a416781 is described below

commit 804b2a4167813ac33f5d2e61898483a66c389059
Author: Stefan Kandic <stefan.kan...@databricks.com>
AuthorDate: Mon Mar 25 10:05:01 2024 +0500

    [SPARK-47506][SQL] Add support to all file source formats for collated data types
    
    ### What changes were proposed in this pull request?
    
    Adding support and tests for collated types in all the file sources currently supported by Spark, including:
     - parquet
     - json
     - csv
     - orc
     - text
    
    Note that collation metadata is preserved only when these file sources are used via the [CREATE TABLE USING DATA_SOURCE](https://spark.apache.org/docs/latest/sql-ref-syntax-ddl-create-table-datasource.html) API. Writing directly to a file with the DataFrame API does not preserve collation metadata (except for parquet, which stores the schema in the file itself), as sketched below.
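
    A minimal sketch of the distinction, mirroring the new test in this patch (the table name and output path are illustrative):

    ```scala
    // Collation metadata survives when the table is created via CREATE TABLE USING:
    spark.sql("CREATE TABLE t (c1 STRING COLLATE UNICODE_CI) USING csv")
    spark.sql("INSERT INTO t VALUES ('aaa')")
    spark.sql("SELECT COLLATION(c1) FROM t").show()   // UNICODE_CI

    // Writing directly with the DataFrame API drops the collation for
    // non-parquet sources, since only parquet stores the schema in the file:
    val df = spark.sql("SELECT 'aaa' COLLATE UNICODE_CI AS c1")
    df.write.format("csv").save("/tmp/collation_demo")
    val readback = spark.read.format("csv").load("/tmp/collation_demo")
    readback.selectExpr("COLLATION(_c0)").show()      // UTF8_BINARY
    ```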
    
    ### Why are the changes needed?
    
    To make collations compatible with all the file sources users can choose from.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users can now create tables with collations using all supported file sources.
    
    ### How was this patch tested?
    
    New unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45641 from stefankandic/fileSources.
    
    Authored-by: Stefan Kandic <stefan.kan...@databricks.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../spark/sql/catalyst/json/JacksonGenerator.scala |  2 +-
 .../execution/datasources/orc/OrcSerializer.scala  |  2 +-
 .../sql/execution/datasources/orc/OrcUtils.scala   |  4 ++
 .../datasources/text/TextFileFormat.scala          |  2 +-
 .../org/apache/spark/sql/CollationSuite.scala      | 49 ++++++++++++++++++----
 5 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index e01457ff1025..c2c6117e1e3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -137,7 +137,7 @@ class JacksonGenerator(
       (row: SpecializedGetters, ordinal: Int) =>
         gen.writeNumber(row.getDouble(ordinal))
 
-    case StringType =>
+    case _: StringType =>
       (row: SpecializedGetters, ordinal: Int) =>
         gen.writeString(row.getUTF8String(ordinal).toString)
 
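For context, a hedged note on this change: with the collation-aware types this patch builds on, `StringType` is a class carrying a collation id, and the `StringType` companion object is only the default UTF8_BINARY instance. An object pattern (`case StringType`) therefore misses collated strings, while a type pattern covers every collation. A minimal sketch (the `describe` helper is hypothetical):

```scala
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types.{DataType, StringType}

// Hypothetical helper mirroring the pattern-match fix above.
def describe(dt: DataType): String = dt match {
  case _: StringType => "string (any collation)" // also matches collated strings
  case other         => other.typeName
}

describe(StringType)                                                    // default UTF8_BINARY instance
describe(StringType(CollationFactory.collationNameToId("UNICODE_CI"))) // collated instance, still matched
```
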
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
index 5ed73c3f78b1..75e3e13b0f7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
@@ -130,7 +130,7 @@ class OrcSerializer(dataSchema: StructType) {
 
 
     // Don't reuse the result object for string and binary as it would cause extra data copy.
-    case StringType => (getter, ordinal) =>
+    case _: StringType => (getter, ordinal) =>
       new Text(getter.getUTF8String(ordinal).getBytes)
 
     case BinaryType => (getter, ordinal) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index 15fa2f88e128..24943b37d059 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -305,6 +305,10 @@ object OrcUtils extends Logging {
          val typeDesc = new TypeDescription(TypeDescription.Category.TIMESTAMP)
           typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME, t.typeName)
           Some(typeDesc)
+        case _: StringType =>
+          val typeDesc = new TypeDescription(TypeDescription.Category.STRING)
+          typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME, StringType.typeName)
+          Some(typeDesc)
         case _ => None
       }
     }
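
A hedged reading of this hunk: ORC type descriptions can carry key/value attributes, and Spark records the Catalyst type name there. Since the plain `StringType.typeName` is written rather than a collated type name, ORC belongs to the collation-non-preserving group described above. A minimal sketch of the underlying ORC API (the literal attribute key below is an assumption; Spark uses the `CATALYST_TYPE_ATTRIBUTE_NAME` constant shown in the hunk):

```scala
import org.apache.orc.TypeDescription
import org.apache.spark.sql.types.StringType

// Sketch: attach a Catalyst type name to an ORC string type as an attribute.
// The key is an assumed stand-in for CATALYST_TYPE_ATTRIBUTE_NAME.
val typeDesc = TypeDescription.createString()
typeDesc.setAttribute("spark.sql.catalyst.type", StringType.typeName) // "string"
println(typeDesc.getAttributeValue("spark.sql.catalyst.type"))
```
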
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index e675f70e2a0d..caa4e3ed386b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -138,6 +138,6 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
   }
 
   override def supportDataType(dataType: DataType): Boolean =
-    dataType == StringType
+    dataType.isInstanceOf[StringType]
 }
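
The same reasoning applies to the equality check replaced here: a collated string type is an instance of `StringType` but is not equal to the default `StringType` object, so `dataType == StringType` would reject collated columns. A minimal sketch:

```scala
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types.StringType

val collated = StringType(CollationFactory.collationNameToId("UNICODE_CI"))
collated == StringType             // false: only the default collation compares equal
collated.isInstanceOf[StringType]  // true: any collation is still a string type
```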
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 146ba63cf402..f0b51a5b2c19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -36,6 +36,10 @@ import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
 class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
   protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName
 
+  private val collationPreservingSources = Seq("parquet")
+  private val collationNonPreservingSources = Seq("orc", "csv", "json", "text")
+  private val allFileBasedDataSources = collationPreservingSources ++ collationNonPreservingSources
+
   test("collate returns proper type") {
     Seq("utf8_binary", "utf8_binary_lcase", "unicode", "unicode_ci").foreach { 
collationName =>
       checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa"))
@@ -424,22 +428,49 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
   }
 
   test("create table with collation") {
-    val tableName = "parquet_dummy_tbl"
+    val tableName = "dummy_tbl"
     val collationName = "UTF8_BINARY_LCASE"
     val collationId = CollationFactory.collationNameToId(collationName)
 
-    withTable(tableName) {
-      sql(
+    allFileBasedDataSources.foreach { format =>
+      withTable(tableName) {
+        sql(
         s"""
-           |CREATE TABLE $tableName (c1 STRING COLLATE $collationName)
-           |USING PARQUET
+           |CREATE TABLE $tableName (
+           |  c1 STRING COLLATE $collationName
+           |)
+           |USING $format
            |""".stripMargin)
 
-      sql(s"INSERT INTO $tableName VALUES ('aaa')")
-      sql(s"INSERT INTO $tableName VALUES ('AAA')")
+        sql(s"INSERT INTO $tableName VALUES ('aaa')")
+        sql(s"INSERT INTO $tableName VALUES ('AAA')")
 
-      checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"), Seq(Row(collationName)))
-      assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId))
+        checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"), Seq(Row(collationName)))
+        assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId))
+      }
+    }
+  }
+
+  test("write collated data to different data sources with dataframe api") {
+    val collationName = "UNICODE_CI"
+
+    allFileBasedDataSources.foreach { format =>
+      withTempPath { path =>
+        val df = sql(s"SELECT c COLLATE $collationName AS c FROM VALUES ('aaa') AS data(c)")
+        df.write.format(format).save(path.getAbsolutePath)
+
+        val readback = spark.read.format(format).load(path.getAbsolutePath)
+        val readbackCollation = if (collationPreservingSources.contains(format)) {
+          collationName
+        } else {
+          "UTF8_BINARY"
+        }
+
+        checkAnswer(readback, Row("aaa"))
+        checkAnswer(
+          readback.selectExpr(s"collation(${readback.columns.head})"),
+          Row(readbackCollation))
+      }
     }
   }
 

