This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e6c914f63079 [SPARK-48157][SQL] Add collation support for CSV expressions
e6c914f63079 is described below

commit e6c914f630793992eba7a409ec2cd061f385ce02
Author: Uros Bojanic <157381213+uros...@users.noreply.github.com>
AuthorDate: Tue May 14 14:17:45 2024 +0800

    [SPARK-48157][SQL] Add collation support for CSV expressions
    
    ### What changes were proposed in this pull request?
    Introduce collation awareness for CSV expressions: from_csv, schema_of_csv, to_csv.
    
    ### Why are the changes needed?
    Add collation support for CSV expressions in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users should now be able to use collated strings within arguments for CSV functions: from_csv, schema_of_csv, to_csv.
    
    ### How was this patch tested?
    End-to-end SQL tests added to CollationSQLExpressionsSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46504 from uros-db/csv-expressions.
    
    Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/expressions/csvExpressions.scala  |   7 +-
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 112 +++++++++++++++++++++
 2 files changed, 116 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index 4714fc1ded9c..cb10440c4832 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.TypeUtils._
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -146,7 +147,7 @@ case class CsvToStructs(
     converter(parser.parse(csv))
   }
 
-  override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
+  override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: 
Nil
 
   override def prettyName: String = "from_csv"
 
@@ -177,7 +178,7 @@ case class SchemaOfCsv(
     child = child,
     options = ExprUtils.convertToMapData(options))
 
-  override def dataType: DataType = StringType
+  override def dataType: DataType = SQLConf.get.defaultStringType
 
   override def nullable: Boolean = false
 
@@ -300,7 +301,7 @@ case class StructsToCsv(
     (row: Any) => 
UTF8String.fromString(gen.writeToString(row.asInstanceOf[InternalRow]))
   }
 
-  override def dataType: DataType = StringType
+  override def dataType: DataType = SQLConf.get.defaultStringType
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 22b29154cd78..f8b3548b956c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -313,6 +313,118 @@ class CollationSQLExpressionsSuite
     })
   }
 
+  test("Support CsvToStructs csv expression with collation") {
+    case class CsvToStructsTestCase(
+     input: String,
+     collationName: String,
+     schema: String,
+     options: String,
+     result: Row,
+     structFields: Seq[StructField]
+    )
+
+    val testCases = Seq(
+      CsvToStructsTestCase("1", "UTF8_BINARY", "'a INT'", "",
+        Row(1), Seq(
+          StructField("a", IntegerType, nullable = true)
+        )),
+      CsvToStructsTestCase("true, 0.8", "UTF8_BINARY_LCASE", "'A BOOLEAN, B 
DOUBLE'", "",
+        Row(true, 0.8), Seq(
+          StructField("A", BooleanType, nullable = true),
+          StructField("B", DoubleType, nullable = true)
+        )),
+      CsvToStructsTestCase("\"Spark\"", "UNICODE", "'a STRING'", "",
+        Row("Spark"), Seq(
+          StructField("a", StringType("UNICODE"), nullable = true)
+        )),
+      CsvToStructsTestCase("26/08/2015", "UTF8_BINARY", "'time Timestamp'",
+        ", map('timestampFormat', 'dd/MM/yyyy')", Row(
+          new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 
00:00:00.0")
+        ), Seq(
+          StructField("time", TimestampType, nullable = true)
+        ))
+    )
+
+    // Supported collations
+    testCases.foreach(t => {
+      val query =
+        s"""
+           |select from_csv('${t.input}', ${t.schema} ${t.options})
+           |""".stripMargin
+      // Result
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
+        val testQuery = sql(query)
+        val queryResult = testQuery.collect().head
+        checkAnswer(testQuery, Row(t.result))
+        val dataType = StructType(t.structFields)
+        assert(testQuery.schema.fields.head.dataType.sameType(dataType))
+      }
+    })
+  }
+
+  test("Support SchemaOfCsv csv expression with collation") {
+    case class SchemaOfCsvTestCase(
+      input: String,
+      collationName: String,
+      result: String
+    )
+
+    val testCases = Seq(
+      SchemaOfCsvTestCase("1", "UTF8_BINARY", "STRUCT<_c0: INT>"),
+      SchemaOfCsvTestCase("true,0.8", "UTF8_BINARY_LCASE",
+        "STRUCT<_c0: BOOLEAN, _c1: DOUBLE>"),
+      SchemaOfCsvTestCase("2015-08-26", "UNICODE", "STRUCT<_c0: DATE>"),
+      SchemaOfCsvTestCase("abc", "UNICODE_CI",
+        "STRUCT<_c0: STRING>")
+    )
+
+    // Supported collations
+    testCases.foreach(t => {
+      val query =
+        s"""
+           |select schema_of_csv('${t.input}')
+           |""".stripMargin
+      // Result
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
+        val testQuery = sql(query)
+        checkAnswer(testQuery, Row(t.result))
+        val dataType = StringType(t.collationName)
+        assert(testQuery.schema.fields.head.dataType.sameType(dataType))
+      }
+    })
+  }
+
+  test("Support StructsToCsv csv expression with collation") {
+    case class StructsToCsvTestCase(
+     input: String,
+     collationName: String,
+     result: String
+    )
+
+    val testCases = Seq(
+      StructsToCsvTestCase("named_struct('a', 1, 'b', 2)", "UTF8_BINARY", 
"1,2"),
+      StructsToCsvTestCase("named_struct('A', true, 'B', 2.0)", 
"UTF8_BINARY_LCASE", "true,2.0"),
+      StructsToCsvTestCase("named_struct()", "UNICODE", null),
+      StructsToCsvTestCase("named_struct('time', to_timestamp('2015-08-26'))", 
"UNICODE_CI",
+        "2015-08-26T00:00:00.000-07:00")
+    )
+
+    // Supported collations
+    testCases.foreach(t => {
+      val query =
+        s"""
+           |select to_csv(${t.input})
+           |""".stripMargin
+      // Result
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
+        val testQuery = sql(query)
+        checkAnswer(testQuery, Row(t.result))
+        val dataType = StringType(t.collationName)
+        assert(testQuery.schema.fields.head.dataType.sameType(dataType))
+      }
+    })
+  }
+
   test("Conv expression with collation") {
     // Supported collations
     case class ConvTestCase(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to