This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 227cd8bb8f5 [SPARK-45523][PYTHON] Return useful error message if UDTF returns None for any non-nullable column
227cd8bb8f5 is described below

commit 227cd8bb8f5d8d442f225057db39591b3c630f46
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Fri Oct 20 15:22:45 2023 -0700

    [SPARK-45523][PYTHON] Return useful error message if UDTF returns None for any non-nullable column
    
    ### What changes were proposed in this pull request?
    
    This PR updates Python UDTF evaluation to return a useful error message if the UDTF returns None for any non-nullable column.
    
    This implementation also checks recursively for None values in subfields of array/struct/map columns.
    
    For example:
    
    ```
    from pyspark.sql.functions import AnalyzeResult
    from pyspark.sql.types import ArrayType, IntegerType, StringType, StructType
    
    class Tvf:
        @staticmethod
        def analyze(*args):
            return AnalyzeResult(
                schema=StructType()
                    .add("result", ArrayType(IntegerType(), 
containsNull=False), True)
                )
    
        def eval(self, *args):
            yield [1, 2, 3, 4],
    
        def terminate(self):
            yield [1, 2, None, 3],
    ```
    
    ```
    SELECT * FROM Tvf(TABLE(VALUES (0), (1)))
    
    > org.apache.spark.api.python.PythonException
    [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or
    'terminate' method: Column 0 within a returned row had a value of None, either
    directly or within array/struct/map subfields, but the corresponding column type was
    declared as non-nullable; please update the UDTF to return a non-None value at this
    location or otherwise declare the column type as nullable.
    ```
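
    As the error message suggests, the example above can be fixed either by returning a
    non-None value or by declaring the offending type as nullable. A minimal sketch of the
    second option (illustrative only, not part of this PR) marks the array elements as
    nullable with `containsNull=True`, after which the same queries succeed:

    ```
    from pyspark.sql.functions import AnalyzeResult
    from pyspark.sql.types import ArrayType, IntegerType, StructType

    class Tvf:
        @staticmethod
        def analyze(*args):
            return AnalyzeResult(
                schema=StructType()
                    # containsNull=True permits None elements in the result array.
                    .add("result", ArrayType(IntegerType(), containsNull=True), True)
                )

        def eval(self, *args):
            yield [1, 2, 3, 4],

        def terminate(self):
            yield [1, 2, None, 3],
    ```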
    
    ### Why are the changes needed?
    
    Previously this case raised a NullPointerException.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, see above.
    
    ### How was this patch tested?
    
    This PR adds new test coverage.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43356 from dtenedor/improve-errors-null-checks.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/worker.py                           |  72 +++++
 .../sql-tests/analyzer-results/udtf/udtf.sql.out   |  60 ++++
 .../test/resources/sql-tests/inputs/udtf/udtf.sql  |  12 +
 .../resources/sql-tests/results/udtf/udtf.sql.out  |  90 ++++++
 .../apache/spark/sql/IntegratedUDFTestUtils.scala  | 359 ++++++++++++++++++++-
 .../org/apache/spark/sql/SQLQueryTestSuite.scala   |  43 ++-
 6 files changed, 626 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index df7dd1bc2f7..b1f59e1619f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -52,9 +52,13 @@ from pyspark.sql.pandas.serializers import (
 )
 from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.types import (
+    ArrayType,
     BinaryType,
+    DataType,
+    MapType,
     Row,
     StringType,
+    StructField,
     StructType,
     _create_row,
     _parse_datatype_json_string,
@@ -841,6 +845,71 @@ def read_udtf(pickleSer, infile, eval_type):
             "the query again."
         )
 
+    # This determines which result columns have nullable types.
+    def check_nullable_column(i: int, data_type: DataType, nullable: bool) -> None:
+        if not nullable:
+            nullable_columns.add(i)
+        elif isinstance(data_type, ArrayType):
+            check_nullable_column(i, data_type.elementType, data_type.containsNull)
+        elif isinstance(data_type, StructType):
+            for subfield in data_type.fields:
+                check_nullable_column(i, subfield.dataType, subfield.nullable)
+        elif isinstance(data_type, MapType):
+            check_nullable_column(i, data_type.valueType, data_type.valueContainsNull)
+
+    nullable_columns: set[int] = set()
+    for i, field in enumerate(return_type.fields):
+        check_nullable_column(i, field.dataType, field.nullable)
+
+    # Compares each UDTF output row against the output schema for this particular UDTF call,
+    # raising an error if the two are incompatible.
+    def check_output_row_against_schema(row: Any, expected_schema: StructType) -> None:
+        for result_column_index in nullable_columns:
+
+            def check_for_none_in_non_nullable_column(
+                value: Any, data_type: DataType, nullable: bool
+            ) -> None:
+                if value is None and not nullable:
+                    raise PySparkRuntimeError(
+                        error_class="UDTF_EXEC_ERROR",
+                        message_parameters={
+                            "method_name": "eval' or 'terminate",
+                            "error": f"Column {result_column_index} within a 
returned row had a "
+                            + "value of None, either directly or within 
array/struct/map "
+                            + "subfields, but the corresponding column type 
was declared as "
+                            + "non-nullable; please update the UDTF to return 
a non-None value at "
+                            + "this location or otherwise declare the column 
type as nullable.",
+                        },
+                    )
+                elif (
+                    isinstance(data_type, ArrayType)
+                    and isinstance(value, list)
+                    and not data_type.containsNull
+                ):
+                    for sub_value in value:
+                        check_for_none_in_non_nullable_column(
+                            sub_value, data_type.elementType, data_type.containsNull
+                        )
+                elif isinstance(data_type, StructType) and isinstance(value, Row):
+                    for i in range(len(value)):
+                        check_for_none_in_non_nullable_column(
+                            value[i], data_type[i].dataType, data_type[i].nullable
+                        )
+                elif isinstance(data_type, MapType) and isinstance(value, dict):
+                    for map_key, map_value in value.items():
+                        check_for_none_in_non_nullable_column(
+                            map_key, data_type.keyType, nullable=False
+                        )
+                        check_for_none_in_non_nullable_column(
+                            map_value, data_type.valueType, data_type.valueContainsNull
+                        )
+
+            field: StructField = expected_schema[result_column_index]
+            if row is not None:
+                check_for_none_in_non_nullable_column(
+                    list(row)[result_column_index], field.dataType, field.nullable
+                )
+
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
 
         def wrap_arrow_udtf(f, return_type):
@@ -879,6 +948,8 @@ def read_udtf(pickleSer, infile, eval_type):
                 verify_pandas_result(
                     result, return_type, assign_cols_by_name=False, truncate_return_schema=False
                 )
+                for result_tuple in result.itertuples():
+                    check_output_row_against_schema(list(result_tuple), return_type)
                 return result
 
             # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
@@ -973,6 +1044,7 @@ def read_udtf(pickleSer, infile, eval_type):
                             },
                         )
 
+                check_output_row_against_schema(result, return_type)
                 return toInternal(result)
 
             # Evaluate the function and return a tuple back to the executor.
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
index 1b923442207..078bd790a84 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
@@ -452,6 +452,66 @@ org.apache.spark.sql.AnalysisException
 }
 
 
+-- !query
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnScalarType(TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnArrayType(TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnArrayElementType(TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnStructType(TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnMapType(TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnScalarType(TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnArrayType(TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType(TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnStructType(TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnMapType(TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
 -- !query
 DROP VIEW t1
 -- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
index 6d34b91e2f1..3d7d0bb3251 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
@@ -104,6 +104,18 @@ SELECT * FROM
     VALUES (0), (1) AS t(col)
     JOIN LATERAL
     UDTFInvalidOrderByWithoutPartitionBy(TABLE(t2) PARTITION BY partition_col);
+-- The following UDTF calls should fail because the UDTF's 'eval' or 'terminate' method returns None
+-- to a non-nullable column, either directly or within an array/struct/map subfield.
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnScalarType(TABLE(t2));
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnArrayType(TABLE(t2));
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnArrayElementType(TABLE(t2));
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnStructType(TABLE(t2));
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnMapType(TABLE(t2));
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnScalarType(TABLE(t2));
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnArrayType(TABLE(t2));
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType(TABLE(t2));
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnStructType(TABLE(t2));
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnMapType(TABLE(t2));
 
 -- cleanup
 DROP VIEW t1;
diff --git a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
index 11295c43d8c..3317f5fada7 100644
--- a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
@@ -525,6 +525,96 @@ org.apache.spark.sql.AnalysisException
 }
 
 
+-- !query
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnScalarType(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.api.python.PythonException
+pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable.
+
+
+-- !query
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnArrayType(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.api.python.PythonException
+pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable.
+
+
+-- !query
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnArrayElementType(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.api.python.PythonException
+pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable.
+
+
+-- !query
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnStructType(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.api.python.PythonException
+pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable.
+
+
+-- !query
+SELECT * FROM InvalidEvalReturnsNoneToNonNullableColumnMapType(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.api.python.PythonException
+pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable.
+
+
+-- !query
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnScalarType(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.api.python.PythonException
+pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable.
+
+
+-- !query
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnArrayType(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.api.python.PythonException
+pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable.
+
+
+-- !query
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.api.python.PythonException
+pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable.
+
+
+-- !query
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnStructType(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.api.python.PythonException
+pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable.
+
+
+-- !query
+SELECT * FROM InvalidTerminateReturnsNoneToNonNullableColumnMapType(TABLE(t2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.api.python.PythonException
+pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' or 'terminate' method: Column 0 within a returned row had a value of None, either directly or within array/struct/map subfields, but the corresponding column type was declared as non-nullable; please update the UDTF to return a non-None value at this location or otherwise declare the column type as nullable.
+
+
 -- !query
 DROP VIEW t1
 -- !query schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index 3c30c414f81..aadeb6fcc8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -628,7 +628,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
       "Python UDTF exporting input table partitioning and ordering requirement 
from 'analyze'"
   }
 
-  object TestPythonUDTFInvalidPartitionByAndWithSinglePartition extends TestUDTF {
+  object InvalidPartitionByAndWithSinglePartition extends TestUDTF {
     val name: String = "UDTFInvalidPartitionByAndWithSinglePartition"
     val pythonScript: String =
       s"""
@@ -668,7 +668,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
         "because the 'with_single_partition' property is also exported to true"
   }
 
-  object TestPythonUDTFInvalidOrderByWithoutPartitionBy extends TestUDTF {
+  object InvalidOrderByWithoutPartitionBy extends TestUDTF {
     val name: String = "UDTFInvalidOrderByWithoutPartitionBy"
     val pythonScript: String =
       s"""
@@ -749,6 +749,361 @@ object IntegratedUDFTestUtils extends SQLHelper {
     val prettyName: String = "Python UDTF whose 'analyze' method sets state 
and reads it later"
   }
 
+  object InvalidEvalReturnsNoneToNonNullableColumnScalarType extends TestUDTF {
+    val name: String = "InvalidEvalReturnsNoneToNonNullableColumnScalarType"
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import StringType, StructType
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(*args):
+         |        return AnalyzeResult(
+         |            schema=StructType()
+         |                .add("result", StringType(), False)
+         |            )
+         |
+         |    def eval(self, *args):
+         |        yield None,
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String =
+      "Invalid Python UDTF whose 'eval' method returns None to a non-nullable 
scalar column"
+  }
+
+  object InvalidEvalReturnsNoneToNonNullableColumnArrayType extends TestUDTF {
+    val name: String = "InvalidEvalReturnsNoneToNonNullableColumnArrayType"
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import ArrayType, IntegerType, StringType, StructType
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(*args):
+         |        return AnalyzeResult(
+         |            schema=StructType()
+         |                .add("result", ArrayType(IntegerType(), 
containsNull=True), False)
+         |            )
+         |
+         |    def eval(self, *args):
+         |        yield None,
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String =
+      "Invalid Python UDTF whose 'eval' method returns None to a non-nullable 
array column"
+  }
+
+  object InvalidEvalReturnsNoneToNonNullableColumnArrayElementType extends TestUDTF {
+    val name: String = "InvalidEvalReturnsNoneToNonNullableColumnArrayElementType"
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import ArrayType, IntegerType, StringType, StructType
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(*args):
+         |        return AnalyzeResult(
+         |            schema=StructType()
+         |                .add("result", ArrayType(IntegerType(), 
containsNull=False), True)
+         |            )
+         |
+         |    def eval(self, *args):
+         |        yield [1, 2, None, 3],
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String =
+      "Invalid Python UDTF whose 'eval' method returns None to a non-nullable 
array element"
+  }
+
+  object InvalidEvalReturnsNoneToNonNullableColumnStructType extends TestUDTF {
+    val name: String = "InvalidEvalReturnsNoneToNonNullableColumnStructType"
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import IntegerType, Row, StringType, StructType
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(*args):
+         |        return AnalyzeResult(
+         |            schema=StructType()
+         |                .add("result", StructType().add("field", 
IntegerType(), False), True)
+         |            )
+         |
+         |    def eval(self, *args):
+         |        yield Row(field=None),
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String =
+      "Invalid Python UDTF whose 'eval' method returns None to a non-nullable 
struct column"
+  }
+
+  object InvalidEvalReturnsNoneToNonNullableColumnMapType extends TestUDTF {
+    val name: String = "InvalidEvalReturnsNoneToNonNullableColumnMapType"
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import IntegerType, MapType, StringType, StructType
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(*args):
+         |        return AnalyzeResult(
+         |            schema=StructType()
+         |                .add("result", MapType(IntegerType(), StringType(), 
False), True)
+         |            )
+         |
+         |    def eval(self, *args):
+         |        yield {42: None},
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String =
+      "Invalid Python UDTF whose 'eval' method returns None to a non-nullable 
map column"
+  }
+
+  object InvalidTerminateReturnsNoneToNonNullableColumnScalarType extends TestUDTF {
+    val name: String = "InvalidTerminateReturnsNoneToNonNullableColumnScalarType"
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import StringType, StructType
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(*args):
+         |        return AnalyzeResult(
+         |            schema=StructType()
+         |                .add("result", StringType(), False)
+         |            )
+         |
+         |    def eval(self, *args):
+         |        yield 'abc',
+         |
+         |    def terminate(self):
+         |        yield None,
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String =
+      "Invalid Python UDTF whose 'terminate' method returns None to a 
non-nullable column"
+  }
+
+  object InvalidTerminateReturnsNoneToNonNullableColumnArrayType extends TestUDTF {
+    val name: String = "InvalidTerminateReturnsNoneToNonNullableColumnArrayType"
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import ArrayType, IntegerType, StringType, StructType
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(*args):
+         |        return AnalyzeResult(
+         |            schema=StructType()
+         |                .add("result", ArrayType(IntegerType(), 
containsNull=True), False)
+         |            )
+         |
+         |    def eval(self, *args):
+         |        yield [1, 2, 3, 4],
+         |
+         |    def terminate(self):
+         |        yield None,
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String =
+      "Invalid Python UDTF whose 'terminate' method returns None to a 
non-nullable array column"
+  }
+
+  object InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType extends TestUDTF {
+    val name: String = "InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType"
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import ArrayType, IntegerType, StringType, StructType
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(*args):
+         |        return AnalyzeResult(
+         |            schema=StructType()
+         |                .add("result", ArrayType(IntegerType(), 
containsNull=False), True)
+         |            )
+         |
+         |    def eval(self, *args):
+         |        yield [1, 2, 3, 4],
+         |
+         |    def terminate(self):
+         |        yield [1, 2, None, 3],
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String =
+      "Invalid Python UDTF whose 'terminate' method returns None to a 
non-nullable array element"
+  }
+
+  object InvalidTerminateReturnsNoneToNonNullableColumnStructType extends TestUDTF {
+    val name: String = "InvalidTerminateReturnsNoneToNonNullableColumnStructType"
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import IntegerType, Row, StringType, StructType
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(*args):
+         |        return AnalyzeResult(
+         |            schema=StructType()
+         |                .add("result", StructType().add("field", 
IntegerType(), False), True)
+         |            )
+         |
+         |    def eval(self, *args):
+         |        yield Row(field=42),
+         |
+         |    def terminate(self):
+         |        yield Row(field=None),
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String =
+      "Invalid Python UDTF whose 'terminate' method returns None to a 
non-nullable struct column"
+  }
+
+  object InvalidTerminateReturnsNoneToNonNullableColumnMapType extends TestUDTF {
+    val name: String = "InvalidTerminateReturnsNoneToNonNullableColumnMapType"
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import IntegerType, MapType, StringType, StructType
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(*args):
+         |        return AnalyzeResult(
+         |            schema=StructType()
+         |                .add("result", MapType(IntegerType(), StringType(), 
False), True)
+         |            )
+         |
+         |    def eval(self, *args):
+         |        yield {42: 'abc'},
+         |
+         |    def terminate(self):
+         |        yield {42: None},
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String =
+      "Invalid Python UDTF whose 'terminate' method returns None to a 
non-nullable map column"
+  }
+
   /**
    * A Scalar Pandas UDF that takes one column, casts into string, executes the
    * Python native function, and casts back to the type of input column.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index 36899d11578..226d5098d42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -541,19 +541,19 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
         case _: AnalyzerTest =>
           val (_, output) =
             handleExceptions(getNormalizedQueryAnalysisResult(localSparkSession, sql))
-          // We might need to do some query canonicalization in the future.
+          // We do some query canonicalization now.
           AnalyzerOutput(
             sql = sql,
             schema = None,
-            output = output.mkString("\n").replaceAll("\\s+$", ""))
+            output = normalizeTestResults(output.mkString("\n")))
         case _ =>
           val (schema, output) =
             handleExceptions(getNormalizedQueryExecutionResult(localSparkSession, sql))
-          // We might need to do some query canonicalization in the future.
+          // We do some query canonicalization now.
           val executionOutput = ExecutionOutput(
             sql = sql,
             schema = Some(schema),
-            output = output.mkString("\n").replaceAll("\\s+$", ""))
+            output = normalizeTestResults(output.mkString("\n")))
           if (testCase.isInstanceOf[CTETest]) {
            expandCTEQueryAndCompareResult(localSparkSession, sql, executionOutput)
           }
@@ -650,8 +650,18 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
           TestPythonUDTFLastString,
           TestPythonUDTFWithSinglePartition,
           TestPythonUDTFPartitionBy,
-          TestPythonUDTFInvalidPartitionByAndWithSinglePartition,
-          TestPythonUDTFInvalidOrderByWithoutPartitionBy
+          InvalidPartitionByAndWithSinglePartition,
+          InvalidOrderByWithoutPartitionBy,
+          InvalidEvalReturnsNoneToNonNullableColumnScalarType,
+          InvalidEvalReturnsNoneToNonNullableColumnArrayType,
+          InvalidEvalReturnsNoneToNonNullableColumnArrayElementType,
+          InvalidEvalReturnsNoneToNonNullableColumnStructType,
+          InvalidEvalReturnsNoneToNonNullableColumnMapType,
+          InvalidTerminateReturnsNoneToNonNullableColumnScalarType,
+          InvalidTerminateReturnsNoneToNonNullableColumnArrayType,
+          InvalidTerminateReturnsNoneToNonNullableColumnArrayElementType,
+          InvalidTerminateReturnsNoneToNonNullableColumnStructType,
+          InvalidTerminateReturnsNoneToNonNullableColumnMapType
         ))).map { udtfSet =>
           UDTFSetTestCase(
             s"$testCaseName - Python UDTFs", absPath, resultFile, udtfSet)
@@ -848,17 +858,18 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
         s"Expected $numSegments blocks in result file but got " +
           s"${segments.size}. Try regenerate the result files.")
       var curSegment = 0
+
       outputs.map { output =>
         val result = if (output.numSegments == 3) {
           makeOutput(
             segments(curSegment + 1).trim, // SQL
             Some(segments(curSegment + 2).trim), // Schema
-            segments(curSegment + 3).replaceAll("\\s+$", "")) // Output
+            normalizeTestResults(segments(curSegment + 3))) // Output
         } else {
           makeOutput(
             segments(curSegment + 1).trim, // SQL
             None, // Schema
-            segments(curSegment + 2).replaceAll("\\s+$", "")) // Output
+            normalizeTestResults(segments(curSegment + 2))) // Output
         }
         curSegment += output.numSegments
         result
@@ -885,6 +896,22 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
     }
   }
 
+  /** This is a helper function to normalize non-deterministic Python error stacktraces. */
+  def normalizeTestResults(output: String): String = {
+    val strippedPythonErrors: String = {
+      var traceback = false
+      output.split("\n").filter { line: String =>
+        if (line == "Traceback (most recent call last):") {
+          traceback = true
+        } else if (!line.startsWith(" ")) {
+          traceback = false
+        }
+        !traceback
+      }.mkString("\n")
+    }
+    strippedPythonErrors.replaceAll("\\s+$", "")
+  }
+
   /** A single SQL query's output. */
   trait QueryTestOutput {
     def sql: String


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

