This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dfb8834f9cad [SPARK-54768][SS] Python Stream Data Source should 
classify error if data returned doesn't match configured schema
dfb8834f9cad is described below

commit dfb8834f9cad21e6eae4f0a6275de6dab7cf4192
Author: Siying Dong <[email protected]>
AuthorDate: Tue Jan 6 14:39:39 2026 -0800

    [SPARK-54768][SS] Python Stream Data Source should classify error if data 
returned doesn't match configured schema
    
    ### What changes were proposed in this pull request?
    In the Python streaming data source, if the records returned by the user's
    reader don't match the schema the user declared, a classified error
    (ARROW_TYPE_MISMATCH) is now thrown instead of an assertion failure.
    
    ### Why are the changes needed?
    This gives users a better experience: they receive a meaningful error
    message and error class instead of an opaque assertion failure.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a unit test to PythonStreamingDataSourceSimpleSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    Generated-by: Cursor 2.2.20 with claude-4.5-sonnet
    
    Closes #53535 from siying/source.
    
    Authored-by: Siying Dong <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../streaming/PythonStreamingSourceRunner.scala    |  5 +-
 .../streaming/PythonStreamingDataSourceSuite.scala | 57 +++++++++++++++++++++-
 2 files changed, 60 insertions(+), 2 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
index 5945b9108427..270d816e9bd9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
@@ -249,7 +249,10 @@ class PythonStreamingSourceRunner(
     val root = reader.getVectorSchemaRoot()
     // When input is empty schema can't be read.
     val schema = ArrowUtils.fromArrowSchema(root.getSchema())
-    assert(schema == outputSchema)
+    if (schema != outputSchema) {
+      throw QueryExecutionErrors.arrowDataTypeMismatchError(
+        "Python streaming data source read", Seq(outputSchema), Seq(schema))
+    }
 
     val vectors = root.getFieldVectors().asScala.map { vector =>
       new ArrowColumnVector(vector)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
index 3b3e8687858d..0e33b6e55a43 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
@@ -29,7 +29,7 @@ import 
org.apache.spark.sql.execution.python.PythonDataSourceSuiteBase
 import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
 import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, 
OffsetSeqLog}
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
-import org.apache.spark.sql.streaming.StreamingQueryException
+import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -340,6 +340,61 @@ class PythonStreamingDataSourceSimpleSuite extends 
PythonDataSourceSuiteBase {
       stream.latestOffset()
     }
   }
+
+  test("SPARK-54768: SimpleDataSourceStreamReader schema mismatch - prefetched 
batches") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, 
SimpleDataSourceStreamReader
+         |import pyarrow as pa
+         |
+         |class SchemaMismatchSimpleReader(SimpleDataSourceStreamReader):
+         |    def initialOffset(self):
+         |        return {"offset": 0}
+         |    def read(self, start: dict):
+         |        # Return PyArrow RecordBatch with STRING when INT is 
expected.
+         |        schema = pa.schema([pa.field("id", pa.string(), 
nullable=True)])
+         |        batch = pa.RecordBatch.from_arrays(
+         |            [pa.array(["1"], type=pa.string())], schema=schema)
+         |        return iter([batch]), {"offset": 1}
+         |    def readBetweenOffsets(self, start: dict, end: dict):
+         |        return iter([])
+         |    def commit(self, end: dict):
+         |        pass
+         |
+         |class $errorDataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT NOT NULL"
+         |    def simpleStreamReader(self, schema):
+         |        return SchemaMismatchSimpleReader()
+         |""".stripMargin
+
+    val dataSource = createUserDefinedPythonDataSource(errorDataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(errorDataSourceName, dataSource)
+
+    val df = spark.readStream.format(errorDataSourceName).load()
+    val err = intercept[StreamingQueryException] {
+      val q = df.writeStream
+        .trigger(Trigger.Once())
+        .foreachBatch((df: DataFrame, _: Long) => {
+          df.count()
+          ()
+        })
+        .start()
+      q.awaitTermination()
+    }
+    assert(err.getCause.isInstanceOf[SparkException])
+    val cause = err.getCause.asInstanceOf[SparkException]
+    checkErrorMatchPVals(
+      cause,
+      condition = "ARROW_TYPE_MISMATCH",
+      parameters = Map(
+        "operation" -> "Python streaming data source read",
+        "outputTypes" -> 
"StructType\\(StructField\\(id,IntegerType,false\\)\\)",
+        "actualDataTypes" -> 
"StructType\\(StructField\\(id,StringType,true\\)\\)"
+      )
+    )
+  }
 }
 
 class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to