This is an automated email from the ASF dual-hosted git repository.
ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dfb8834f9cad [SPARK-54768][SS] Python Stream Data Source should classify error if data returned doesn't match configured schema
dfb8834f9cad is described below
commit dfb8834f9cad21e6eae4f0a6275de6dab7cf4192
Author: Siying Dong <[email protected]>
AuthorDate: Tue Jan 6 14:39:39 2026 -0800
[SPARK-54768][SS] Python Stream Data Source should classify error if data returned doesn't match configured schema
### What changes were proposed in this pull request?
In the Python Stream Data Source, if the records returned by the user do not match the schema they provided, a classified error (`ARROW_TYPE_MISMATCH`) is now thrown instead of an assertion failure.
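For illustration, here is a minimal sketch of user code that triggers the mismatch (adapted from the unit test added below; the class and source names are hypothetical):

```python
# Minimal sketch of a mismatching source (assumes PySpark with the Python
# data source API; class and source names are hypothetical, adapted from
# the unit test in this commit).
import pyarrow as pa
from pyspark.sql.datasource import DataSource, SimpleDataSourceStreamReader

class MismatchReader(SimpleDataSourceStreamReader):
    def initialOffset(self):
        return {"offset": 0}

    def read(self, start: dict):
        # The declared schema promises INT, but the batch carries a STRING.
        batch = pa.RecordBatch.from_arrays(
            [pa.array(["1"], type=pa.string())], names=["id"])
        return iter([batch]), {"offset": 1}

    def readBetweenOffsets(self, start: dict, end: dict):
        return iter([])

    def commit(self, end: dict):
        pass

class MismatchSource(DataSource):
    @classmethod
    def name(cls):
        return "mismatch_source"

    def schema(self):
        return "id INT NOT NULL"  # does not match what read() returns

    def simpleStreamReader(self, schema):
        return MismatchReader()
```

Previously such a source tripped `assert(schema == outputSchema)` in `PythonStreamingSourceRunner`; with this change the query instead fails with the classified `ARROW_TYPE_MISMATCH` error.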
### Why are the changes needed?
This provides a better experience for users, who now receive a proper error message and error class instead of an internal assertion failure.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a unit test.
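For context, a minimal sketch of how such a source could be driven from Python to surface the classified error (assumes an active SparkSession `spark` and the hypothetical `MismatchSource` from the sketch above; the new Scala unit test below exercises the same path):

```python
# Usage sketch (assumptions: an active SparkSession `spark`, and the
# hypothetical MismatchSource class from the sketch above).
from pyspark.errors import StreamingQueryException

spark.dataSource.register(MismatchSource)

df = spark.readStream.format("mismatch_source").load()
try:
    query = df.writeStream.format("noop").trigger(availableNow=True).start()
    query.awaitTermination()
except StreamingQueryException as e:
    # With this change, the underlying failure carries the classified
    # ARROW_TYPE_MISMATCH error instead of an internal assertion failure.
    print(e)
```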
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Cursor 2.2.20 with claude-4.5-sonnet
Closes #53535 from siying/source.
Authored-by: Siying Dong <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
---
.../streaming/PythonStreamingSourceRunner.scala | 5 +-
.../streaming/PythonStreamingDataSourceSuite.scala | 57 +++++++++++++++++++++-
2 files changed, 60 insertions(+), 2 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
index 5945b9108427..270d816e9bd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
@@ -249,7 +249,10 @@ class PythonStreamingSourceRunner(
val root = reader.getVectorSchemaRoot()
// When input is empty schema can't be read.
val schema = ArrowUtils.fromArrowSchema(root.getSchema())
- assert(schema == outputSchema)
+ if (schema != outputSchema) {
+ throw QueryExecutionErrors.arrowDataTypeMismatchError(
+ "Python streaming data source read", Seq(outputSchema), Seq(schema))
+ }
val vectors = root.getFieldVectors().asScala.map { vector =>
new ArrowColumnVector(vector)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
index 3b3e8687858d..0e33b6e55a43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.python.PythonDataSourceSuiteBase
import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, OffsetSeqLog}
import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
-import org.apache.spark.sql.streaming.StreamingQueryException
+import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -340,6 +340,61 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
stream.latestOffset()
}
}
+
+ test("SPARK-54768: SimpleDataSourceStreamReader schema mismatch - prefetched
batches") {
+ assume(shouldTestPandasUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, SimpleDataSourceStreamReader
+ |import pyarrow as pa
+ |
+ |class SchemaMismatchSimpleReader(SimpleDataSourceStreamReader):
+ |    def initialOffset(self):
+ |        return {"offset": 0}
+ |    def read(self, start: dict):
+ |        # Return PyArrow RecordBatch with STRING when INT is expected.
+ |        schema = pa.schema([pa.field("id", pa.string(), nullable=True)])
+ |        batch = pa.RecordBatch.from_arrays(
+ |            [pa.array(["1"], type=pa.string())], schema=schema)
+ |        return iter([batch]), {"offset": 1}
+ |    def readBetweenOffsets(self, start: dict, end: dict):
+ |        return iter([])
+ |    def commit(self, end: dict):
+ |        pass
+ |
+ |class $errorDataSourceName(DataSource):
+ |    def schema(self) -> str:
+ |        return "id INT NOT NULL"
+ |    def simpleStreamReader(self, schema):
+ |        return SchemaMismatchSimpleReader()
+ |""".stripMargin
+
+ val dataSource = createUserDefinedPythonDataSource(errorDataSourceName, dataSourceScript)
+ spark.dataSource.registerPython(errorDataSourceName, dataSource)
+
+ val df = spark.readStream.format(errorDataSourceName).load()
+ val err = intercept[StreamingQueryException] {
+ val q = df.writeStream
+ .trigger(Trigger.Once())
+ .foreachBatch((df: DataFrame, _: Long) => {
+ df.count()
+ ()
+ })
+ .start()
+ q.awaitTermination()
+ }
+ assert(err.getCause.isInstanceOf[SparkException])
+ val cause = err.getCause.asInstanceOf[SparkException]
+ checkErrorMatchPVals(
+ cause,
+ condition = "ARROW_TYPE_MISMATCH",
+ parameters = Map(
+ "operation" -> "Python streaming data source read",
+ "outputTypes" ->
"StructType\\(StructField\\(id,IntegerType,false\\)\\)",
+ "actualDataTypes" ->
"StructType\\(StructField\\(id,StringType,true\\)\\)"
+ )
+ )
+ }
}
class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {