This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b1f1a6932f2 [SPARK-52976][PYTHON] Fix Python UDF not accepting collated string as input param/return type
6b1f1a6932f2 is described below

commit 6b1f1a6932f2c7dc33b0064db0b2c44f5361a710
Author: ilicmarkodb <marko.i...@databricks.com>
AuthorDate: Fri Aug 8 10:25:16 2025 +0800

    [SPARK-52976][PYTHON] Fix Python UDF not accepting collated string as input param/return type
    
    ### What changes were proposed in this pull request?
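    Allow collated string types (e.g. `STRING COLLATE UTF8_LCASE`) as input parameter and return
    types of Python UDFs and UDTFs. The Arrow result-type checks in `ArrowEvalPythonExec` and
    `ArrowEvalPythonUDTFExec` now use `DataType.equalsIgnoreCompatibleCollation`, which by default
    also ignores string collations nested in array, map and struct types
    (`AlterTableChangeColumnCommand` passes `checkComplexTypes = false` to keep its previous
    behavior). `EvaluatePython` now converts `UTF8String` values for every `StringType`, not only
    for the `StringType` object itself.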
    
    ### Why are the changes needed?
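    Bug fix: Python UDFs and UDTFs whose inputs or declared return types were collated strings did
    not work. For example, the Arrow paths compared the declared and actual result types with
    strict equality and raised `arrowDataTypeMismatchError` when only the collation differed.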
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
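    New tests: `test_udf_with_collated_string_types` in `test_udf.py` and
    `test_udtf_with_collated_string_types` in `test_udtf.py`. Also updated the nested-collation
    expectations for `equalsIgnoreCompatibleCollation` in `DataTypeSuite`.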
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #51688 from ilicmarkodb/fix_collated_string_as_input_of_python_udf.
    
    Authored-by: ilicmarkodb <marko.i...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 python/pyspark/sql/tests/test_udf.py               | 23 ++++++++++++++
 python/pyspark/sql/tests/test_udtf.py              | 35 ++++++++++++++++++++++
 .../org/apache/spark/sql/types/DataType.scala      | 34 ++++++++++++++++-----
 .../org/apache/spark/sql/types/DataTypeSuite.scala | 18 +++++------
 .../apache/spark/sql/execution/command/ddl.scala   |  2 +-
 .../sql/execution/python/ArrowEvalPythonExec.scala |  3 +-
 .../execution/python/ArrowEvalPythonUDTFExec.scala |  3 +-
 .../sql/execution/python/EvaluatePython.scala      |  2 +-
 8 files changed, 100 insertions(+), 20 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index bd4db5306cb7..af85a73581a7 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -1379,6 +1379,29 @@ class BaseUDFTestsMixin(object):
         result = empty_df.select(add1("id"))
         self.assertEqual(result.collect(), [])
 
+    def test_udf_with_collated_string_types(self):
+        @udf("string collate fr")
+        def my_udf(input_val):
+            return "%s - %s" % (type(input_val), input_val)
+
+        string_types = [
+            StringType(),
+            StringType("UTF8_BINARY"),
+            StringType("UTF8_LCASE"),
+            StringType("UNICODE"),
+        ]
+        data = [("hello",)]
+        expected = "<class 'str'> - hello"
+
+        for string_type in string_types:
+            schema = StructType([StructField("input_col", string_type, True)])
+            df = self.spark.createDataFrame(data, schema=schema)
+            df_result = df.select(my_udf(df.input_col).alias("result"))
+            row = df_result.collect()[0][0]
+            self.assertEqual(row, expected)
+            result_type = df_result.schema["result"].dataType
+            self.assertEqual(result_type, StringType("fr"))
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 1c473daff74e..c8d7d9f14563 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -3490,6 +3490,41 @@ class UDTFArrowTestsMixin(LegacyUDTFArrowTestsMixin):
                     udtf(TestUDTF, returnType=ret_type)().collect()
 
+    def test_udtf_with_collated_string_types(self):
+        # Collated string inputs must be accepted and declared output collations preserved.
+        @udtf(
+            "out1 string, out2 string collate UTF8_BINARY, out3 string collate UTF8_LCASE,"
+            " out4 string collate UNICODE"
+        )
+        class MyUDTF:
+            def eval(self, v1, v2, v3, v4):
+                yield (v1 + "1", v2 + "2", v3 + "3", v4 + "4")
+
+        schema = StructType(
+            [
+                StructField("col1", StringType(), True),
+                StructField("col2", StringType("UTF8_BINARY"), True),
+                StructField("col3", StringType("UTF8_LCASE"), True),
+                StructField("col4", StringType("UNICODE"), True),
+            ]
+        )
+        df = self.spark.createDataFrame([("hello",) * 4], schema=schema)
+
+        df_out = df.select(MyUDTF(df.col1, df.col2, df.col3, df.col4).alias("out"))
+        result_df = df_out.select("out.*")
+
+        expected_row = ("hello1", "hello2", "hello3", "hello4")
+        self.assertEqual(result_df.collect()[0], expected_row)
+
+        expected_output_types = [
+            StringType(),
+            StringType("UTF8_BINARY"),
+            StringType("UTF8_LCASE"),
+            StringType("UNICODE"),
+        ]
+        for idx, field in enumerate(result_df.schema.fields):
+            self.assertEqual(field.dataType, expected_output_types[idx])
+
 
 class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
     def setUpClass(cls):
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 0bf1440f2944..fce0807a7d23 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -448,15 +448,35 @@ object DataType {
   }
 
   /**
-   * Check if `from` is equal to `to` type except for collations, which are checked to be
-   * compatible so that data of type `from` can be interpreted as of type `to`.
+   * Compares two data types, ignoring compatible collation of StringType. If `checkComplexTypes`
+   * is true, it will also ignore collations for nested types.
    */
-  private[sql] def equalsIgnoreCompatibleCollation(from: DataType, to: DataType): Boolean = {
-    (from, to) match {
-      // String types with possibly different collations are compatible.
-      case (a: StringType, b: StringType) => a.constraint == b.constraint
+  private[sql] def equalsIgnoreCompatibleCollation(
+      from: DataType,
+      to: DataType,
+      checkComplexTypes: Boolean = true): Boolean = {
+    def transform: PartialFunction[DataType, DataType] = {
+      case dt @ (_: CharType | _: VarcharType) => dt
+      case _: StringType => StringType
+    }
 
-      case (fromDataType, toDataType) => fromDataType == toDataType
+    if (checkComplexTypes) {
+      from.transformRecursively(transform) == to.transformRecursively(transform)
+    } else {
+      (from, to) match {
+        case (a: StringType, b: StringType) => a.constraint == b.constraint
+
+        case (fromDataType, toDataType) => fromDataType == toDataType
+      }
+    }
+  }
+
+  private[sql] def equalsIgnoreCompatibleCollation(
+      from: Seq[DataType],
+      to: Seq[DataType]): Boolean = {
+    from.length == to.length &&
+    from.zip(to).forall { case (fromDataType, toDataType) =>
+      equalsIgnoreCompatibleCollation(fromDataType, toDataType)
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index af3c36ed621f..c88b0fd99646 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -880,7 +880,7 @@ class DataTypeSuite extends SparkFunSuite {
   checkEqualsIgnoreCompatibleCollation(
     ArrayType(StringType),
     ArrayType(StringType("UTF8_LCASE")),
-    expected = false
+    expected = true
   )
   checkEqualsIgnoreCompatibleCollation(
     ArrayType(StringType),
@@ -890,7 +890,7 @@ class DataTypeSuite extends SparkFunSuite {
   checkEqualsIgnoreCompatibleCollation(
     ArrayType(ArrayType(StringType)),
     ArrayType(ArrayType(StringType("UTF8_LCASE"))),
-    expected = false
+    expected = true
   )
   checkEqualsIgnoreCompatibleCollation(
     ArrayType(ArrayType(StringType)),
@@ -915,12 +915,12 @@ class DataTypeSuite extends SparkFunSuite {
   checkEqualsIgnoreCompatibleCollation(
     MapType(StringType, StringType),
     MapType(StringType, StringType("UTF8_LCASE")),
-    expected = false
+    expected = true
   )
   checkEqualsIgnoreCompatibleCollation(
     MapType(StringType("UTF8_LCASE"), StringType),
     MapType(StringType, StringType),
-    expected = false
+    expected = true
   )
   checkEqualsIgnoreCompatibleCollation(
     MapType(StringType("UTF8_LCASE"), StringType),
@@ -945,7 +945,7 @@ class DataTypeSuite extends SparkFunSuite {
   checkEqualsIgnoreCompatibleCollation(
     MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
     MapType(StringType("UTF8_LCASE"), ArrayType(StringType("UTF8_LCASE"))),
-    expected = false
+    expected = true
   )
   checkEqualsIgnoreCompatibleCollation(
     MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
@@ -970,7 +970,7 @@ class DataTypeSuite extends SparkFunSuite {
   checkEqualsIgnoreCompatibleCollation(
     MapType(ArrayType(StringType), IntegerType),
     MapType(ArrayType(StringType("UTF8_LCASE")), IntegerType),
-    expected = false
+    expected = true
   )
   checkEqualsIgnoreCompatibleCollation(
     MapType(ArrayType(StringType("UTF8_LCASE")), IntegerType),
@@ -1000,7 +1000,7 @@ class DataTypeSuite extends SparkFunSuite {
   checkEqualsIgnoreCompatibleCollation(
     StructType(StructField("a", StringType) :: Nil),
     StructType(StructField("a", StringType("UTF8_LCASE")) :: Nil),
-    expected = false
+    expected = true
   )
   checkEqualsIgnoreCompatibleCollation(
     StructType(StructField("a", StringType) :: Nil),
@@ -1025,7 +1025,7 @@ class DataTypeSuite extends SparkFunSuite {
   checkEqualsIgnoreCompatibleCollation(
     StructType(StructField("a", ArrayType(StringType)) :: Nil),
     StructType(StructField("a", ArrayType(StringType("UTF8_LCASE"))) :: Nil),
-    expected = false
+    expected = true
   )
   checkEqualsIgnoreCompatibleCollation(
     StructType(StructField("a", ArrayType(StringType)) :: Nil),
@@ -1050,7 +1050,7 @@ class DataTypeSuite extends SparkFunSuite {
   checkEqualsIgnoreCompatibleCollation(
     StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),
     StructType(StructField("a", MapType(StringType("UTF8_LCASE"), IntegerType)) :: Nil),
-    expected = false
+    expected = true
   )
   checkEqualsIgnoreCompatibleCollation(
     StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 7df687a4963e..8a4f586edfe0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -465,7 +465,7 @@ case class AlterTableChangeColumnCommand(
   // when altering column. Only changes in collation of data type or its nested types (recursively)
   // are allowed.
   private def canEvolveType(from: StructField, to: StructField): Boolean = {
-    DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType)
+    DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType, checkComplexTypes = false)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 9ec454731e4a..72038f8d4519 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.{StructType, UserDefinedType}
+import org.apache.spark.sql.types.DataType.equalsIgnoreCompatibleCollation
 
 /**
  * Grouped a iterator into batches.
@@ -128,7 +129,7 @@ class ArrowEvalPythonEvaluatorFactory(
 
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
-      if (outputTypes != actualDataTypes) {
+      if (!equalsIgnoreCompatibleCollation(outputTypes, actualDataTypes)) {
         throw QueryExecutionErrors.arrowDataTypeMismatchError(
           "pandas_udf()", outputTypes, actualDataTypes)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 6a6b08a97330..ae1982ecec20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.{StructType, UserDefinedType}
+import org.apache.spark.sql.types.DataType.equalsIgnoreCompatibleCollation
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
@@ -84,7 +85,7 @@ case class ArrowEvalPythonUDTFExec(
 
       val actualDataTypes = (0 until flattenedBatch.numCols()).map(
         i => flattenedBatch.column(i).dataType())
-      if (outputTypes != actualDataTypes) {
+      if (!equalsIgnoreCompatibleCollation(outputTypes, actualDataTypes)) {
         throw QueryExecutionErrors.arrowDataTypeMismatchError(
           "Python UDTF", outputTypes, actualDataTypes)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index eb6fad8d1a3c..5d117a67e6be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -79,7 +79,7 @@ object EvaluatePython {
 
     case (d: Decimal, _) => d.toJavaBigDecimal
 
-    case (s: UTF8String, StringType) => s.toString
+    case (s: UTF8String, _: StringType) => s.toString
 
     case (other, _) => other
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to