This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 6b1f1a6932f2  [SPARK-52976][PYTHON] Fix Python UDF not accepting collated string as input param/return type
6b1f1a6932f2 is described below

commit 6b1f1a6932f2c7dc33b0064db0b2c44f5361a710
Author: ilicmarkodb <marko.i...@databricks.com>
AuthorDate: Fri Aug 8 10:25:16 2025 +0800

    [SPARK-52976][PYTHON] Fix Python UDF not accepting collated string as input param/return type

    ### What changes were proposed in this pull request?
    Fix Python UDF not accepting collated strings as input param/return type.

    ### Why are the changes needed?
    Bug fix.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    New tests.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #51688 from ilicmarkodb/fix_collated_string_as_input_of_python_udf.

    Authored-by: ilicmarkodb <marko.i...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 python/pyspark/sql/tests/test_udf.py               | 23 ++++++++++++++
 python/pyspark/sql/tests/test_udtf.py              | 35 ++++++++++++++++++++++
 .../org/apache/spark/sql/types/DataType.scala      | 34 ++++++++++++++++-----
 .../org/apache/spark/sql/types/DataTypeSuite.scala | 18 +++++------
 .../apache/spark/sql/execution/command/ddl.scala   |  2 +-
 .../sql/execution/python/ArrowEvalPythonExec.scala |  3 +-
 .../execution/python/ArrowEvalPythonUDTFExec.scala |  3 +-
 .../sql/execution/python/EvaluatePython.scala      |  2 +-
 8 files changed, 100 insertions(+), 20 deletions(-)
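For context, a minimal sketch of the previously failing scenario, distilled from the new test below. It assumes an active `spark` session; the column name and data are illustrative. Before this commit, feeding a collated string column into a Python UDF tripped the executor-side Arrow type-mismatch check seen in ArrowEvalPythonExec.scala further down.

    from pyspark.sql.functions import udf
    from pyspark.sql.types import StringType, StructField, StructType

    @udf("string collate fr")  # declared return type carries a collation
    def describe(v):
        # The UDF sees a plain Python str regardless of the column's collation.
        return "%s - %s" % (type(v), v)

    schema = StructType([StructField("name", StringType("UTF8_LCASE"), True)])
    df = spark.createDataFrame([("hello",)], schema=schema)

    # Previously raised an Arrow data type mismatch; now returns
    # "<class 'str'> - hello" with result type StringType("fr").
    df.select(describe("name").alias("result")).show(truncate=False)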
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index bd4db5306cb7..af85a73581a7 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -1379,6 +1379,29 @@ class BaseUDFTestsMixin(object):
         result = empty_df.select(add1("id"))
         self.assertEqual(result.collect(), [])
 
+    def test_udf_with_collated_string_types(self):
+        @udf("string collate fr")
+        def my_udf(input_val):
+            return "%s - %s" % (type(input_val), input_val)
+
+        string_types = [
+            StringType(),
+            StringType("UTF8_BINARY"),
+            StringType("UTF8_LCASE"),
+            StringType("UNICODE"),
+        ]
+        data = [("hello",)]
+        expected = "<class 'str'> - hello"
+
+        for string_type in string_types:
+            schema = StructType([StructField("input_col", string_type, True)])
+            df = self.spark.createDataFrame(data, schema=schema)
+            df_result = df.select(my_udf(df.input_col).alias("result"))
+            row = df_result.collect()[0][0]
+            self.assertEqual(row, expected)
+            result_type = df_result.schema["result"].dataType
+            self.assertEqual(result_type, StringType("fr"))
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 1c473daff74e..c8d7d9f14563 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -3490,6 +3490,41 @@ class UDTFArrowTestsMixin(LegacyUDTFArrowTestsMixin):
                 udtf(TestUDTF, returnType=ret_type)().collect()
 
+    def test_udtf_with_collated_string_types(self):
+        @udtf(
+            "out1 string, out2 string collate UTF8_BINARY, out3 string collate UTF8_LCASE,"
+            " out4 string collate UNICODE"
+        )
+        class MyUDTF:
+            def eval(self, v1, v2, v3, v4):
+                yield (v1 + "1", v2 + "2", v3 + "3", v4 + "4")
+
+        schema = StructType(
+            [
+                StructField("col1", StringType(), True),
+                StructField("col2", StringType("UTF8_BINARY"), True),
+                StructField("col3", StringType("UTF8_LCASE"), True),
+                StructField("col4", StringType("UNICODE"), True),
+            ]
+        )
+        df = self.spark.createDataFrame([("hello",) * 4], schema=schema)
+
+        df_out = df.select(MyUDTF(df.col1, df.col2, df.col3, df.col4).alias("out"))
+        result_df = df_out.select("out.*")
+
+        expected_row = ("hello1", "hello2", "hello3", "hello4")
+        self.assertEqual(result_df.collect()[0], expected_row)
+
+        expected_output_types = [
+            StringType(),
+            StringType("UTF8_BINARY"),
+            StringType("UTF8_LCASE"),
+            StringType("UNICODE"),
+        ]
+        for idx, field in enumerate(result_df.schema.fields):
+            self.assertEqual(field.dataType, expected_output_types[idx])
+
 
 class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
     def setUpClass(cls):
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 0bf1440f2944..fce0807a7d23 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -448,15 +448,35 @@ object DataType {
   }
 
   /**
-   * Check if `from` is equal to `to` type except for collations, which are checked to be
-   * compatible so that data of type `from` can be interpreted as of type `to`.
+   * Compares two data types, ignoring compatible collation of StringType. If `checkComplexTypes`
+   * is true, it will also ignore collations for nested types.
    */
-  private[sql] def equalsIgnoreCompatibleCollation(from: DataType, to: DataType): Boolean = {
-    (from, to) match {
-      // String types with possibly different collations are compatible.
-      case (a: StringType, b: StringType) => a.constraint == b.constraint
+  private[sql] def equalsIgnoreCompatibleCollation(
+      from: DataType,
+      to: DataType,
+      checkComplexTypes: Boolean = true): Boolean = {
+    def transform: PartialFunction[DataType, DataType] = {
+      case dt @ (_: CharType | _: VarcharType) => dt
+      case _: StringType => StringType
+    }
 
-      case (fromDataType, toDataType) => fromDataType == toDataType
+    if (checkComplexTypes) {
+      from.transformRecursively(transform) == to.transformRecursively(transform)
+    } else {
+      (from, to) match {
+        case (a: StringType, b: StringType) => a.constraint == b.constraint
+
+        case (fromDataType, toDataType) => fromDataType == toDataType
+      }
+    }
+  }
+
+  private[sql] def equalsIgnoreCompatibleCollation(
+      from: Seq[DataType],
+      to: Seq[DataType]): Boolean = {
+    from.length == to.length &&
+    from.zip(to).forall { case (fromDataType, toDataType) =>
+      equalsIgnoreCompatibleCollation(fromDataType, toDataType)
     }
   }
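In effect, the new default path (`checkComplexTypes = true`) normalizes every string type that is not char/varchar down to the default StringType, recursively, and then compares for exact equality. A rough PySpark-side mirror of that idea, purely illustrative: the authoritative logic is the Scala `transformRecursively` above, and the helper name here is made up.

    from pyspark.sql.types import (
        ArrayType,
        DataType,
        MapType,
        StringType,
        StructField,
        StructType,
    )

    def equals_ignore_compatible_collation(a: DataType, b: DataType) -> bool:
        """Illustrative mirror of DataType.equalsIgnoreCompatibleCollation with
        checkComplexTypes=true: strip collations recursively, then compare."""

        def normalize(dt: DataType) -> DataType:
            # In PySpark, CharType/VarcharType are not StringType subclasses,
            # so they fall through to `return dt`, matching the Scala carve-out.
            if isinstance(dt, StringType):
                return StringType()  # drop any collation
            if isinstance(dt, ArrayType):
                return ArrayType(normalize(dt.elementType), dt.containsNull)
            if isinstance(dt, MapType):
                return MapType(
                    normalize(dt.keyType), normalize(dt.valueType), dt.valueContainsNull
                )
            if isinstance(dt, StructType):
                return StructType(
                    [StructField(f.name, normalize(f.dataType), f.nullable) for f in dt.fields]
                )
            return dt

        return normalize(a) == normalize(b)

    # Mirrors the flipped expectations in DataTypeSuite below:
    assert equals_ignore_compatible_collation(
        ArrayType(StringType()), ArrayType(StringType("UTF8_LCASE"))
    )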
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index af3c36ed621f..c88b0fd99646 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -880,7 +880,7 @@ class DataTypeSuite extends SparkFunSuite {
     checkEqualsIgnoreCompatibleCollation(
       ArrayType(StringType),
       ArrayType(StringType("UTF8_LCASE")),
-      expected = false
+      expected = true
     )
     checkEqualsIgnoreCompatibleCollation(
       ArrayType(StringType),
@@ -890,7 +890,7 @@
     checkEqualsIgnoreCompatibleCollation(
       ArrayType(ArrayType(StringType)),
       ArrayType(ArrayType(StringType("UTF8_LCASE"))),
-      expected = false
+      expected = true
     )
     checkEqualsIgnoreCompatibleCollation(
       ArrayType(ArrayType(StringType)),
@@ -915,12 +915,12 @@
     checkEqualsIgnoreCompatibleCollation(
       MapType(StringType, StringType),
       MapType(StringType, StringType("UTF8_LCASE")),
-      expected = false
+      expected = true
     )
     checkEqualsIgnoreCompatibleCollation(
       MapType(StringType("UTF8_LCASE"), StringType),
       MapType(StringType, StringType),
-      expected = false
+      expected = true
     )
     checkEqualsIgnoreCompatibleCollation(
       MapType(StringType("UTF8_LCASE"), StringType),
@@ -945,7 +945,7 @@
     checkEqualsIgnoreCompatibleCollation(
       MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
       MapType(StringType("UTF8_LCASE"), ArrayType(StringType("UTF8_LCASE"))),
-      expected = false
+      expected = true
     )
     checkEqualsIgnoreCompatibleCollation(
       MapType(StringType("UTF8_LCASE"), ArrayType(StringType)),
@@ -970,7 +970,7 @@
     checkEqualsIgnoreCompatibleCollation(
       MapType(ArrayType(StringType), IntegerType),
       MapType(ArrayType(StringType("UTF8_LCASE")), IntegerType),
-      expected = false
+      expected = true
     )
     checkEqualsIgnoreCompatibleCollation(
       MapType(ArrayType(StringType("UTF8_LCASE")), IntegerType),
@@ -1000,7 +1000,7 @@
     checkEqualsIgnoreCompatibleCollation(
       StructType(StructField("a", StringType) :: Nil),
       StructType(StructField("a", StringType("UTF8_LCASE")) :: Nil),
-      expected = false
+      expected = true
     )
     checkEqualsIgnoreCompatibleCollation(
       StructType(StructField("a", StringType) :: Nil),
@@ -1025,7 +1025,7 @@
     checkEqualsIgnoreCompatibleCollation(
       StructType(StructField("a", ArrayType(StringType)) :: Nil),
       StructType(StructField("a", ArrayType(StringType("UTF8_LCASE"))) :: Nil),
-      expected = false
+      expected = true
     )
     checkEqualsIgnoreCompatibleCollation(
       StructType(StructField("a", ArrayType(StringType)) :: Nil),
@@ -1050,7 +1050,7 @@
     checkEqualsIgnoreCompatibleCollation(
       StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),
       StructType(StructField("a", MapType(StringType("UTF8_LCASE"), IntegerType)) :: Nil),
-      expected = false
+      expected = true
     )
     checkEqualsIgnoreCompatibleCollation(
       StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 7df687a4963e..8a4f586edfe0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -465,7 +465,7 @@ case class AlterTableChangeColumnCommand(
   // when altering column. Only changes in collation of data type or its nested types (recursively)
   // are allowed.
   private def canEvolveType(from: StructField, to: StructField): Boolean = {
-    DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType)
+    DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType, checkComplexTypes = false)
   }
 }
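Note that ALTER TABLE ... ALTER COLUMN deliberately opts out of the relaxed recursive comparison by passing `checkComplexTypes = false`, so column type evolution keeps its previous, stricter check. A hedged sketch of the intent, driven through `spark.sql`; the table name, format, and collations are illustrative:

    # Assumes an active `spark` session.
    spark.sql("CREATE TABLE demo (name STRING COLLATE UTF8_LCASE) USING parquet")

    # Changing only the collation of a top-level string column is allowed ...
    spark.sql("ALTER TABLE demo ALTER COLUMN name TYPE STRING COLLATE UNICODE")

    # ... while changing the underlying type is still rejected with an analysis error:
    # spark.sql("ALTER TABLE demo ALTER COLUMN name TYPE INT")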
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 9ec454731e4a..72038f8d4519 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.{StructType, UserDefinedType}
+import org.apache.spark.sql.types.DataType.equalsIgnoreCompatibleCollation
 
 /**
  * Grouped a iterator into batches.
@@ -128,7 +129,7 @@ class ArrowEvalPythonEvaluatorFactory(
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
 
-      if (outputTypes != actualDataTypes) {
+      if (!equalsIgnoreCompatibleCollation(outputTypes, actualDataTypes)) {
         throw QueryExecutionErrors.arrowDataTypeMismatchError(
           "pandas_udf()", outputTypes, actualDataTypes)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 6a6b08a97330..ae1982ecec20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.{StructType, UserDefinedType}
+import org.apache.spark.sql.types.DataType.equalsIgnoreCompatibleCollation
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
@@ -84,7 +85,7 @@ case class ArrowEvalPythonUDTFExec(
       val actualDataTypes = (0 until flattenedBatch.numCols()).map(
         i => flattenedBatch.column(i).dataType())
 
-      if (outputTypes != actualDataTypes) {
+      if (!equalsIgnoreCompatibleCollation(outputTypes, actualDataTypes)) {
         throw QueryExecutionErrors.arrowDataTypeMismatchError(
           "Python UDTF", outputTypes, actualDataTypes)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index eb6fad8d1a3c..5d117a67e6be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -79,7 +79,7 @@ object EvaluatePython {
 
     case (d: Decimal, _) => d.toJavaBigDecimal
 
-    case (s: UTF8String, StringType) => s.toString
+    case (s: UTF8String, _: StringType) => s.toString
 
     case (other, _) => other
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org