This is an automated email from the ASF dual-hosted git repository. cloud-fan pushed a commit to branch branch-4.x in repository https://gitbox.apache.org/repos/asf/spark.git
commit b40841930601557d7e1e9a807261e9ce3d0772c9 Author: Andreas Chatzistergiou <[email protected]> AuthorDate: Wed May 20 09:34:43 2026 -0700 [SPARK-56676][SQL][DML] DSv2 Transactional Streaming Writes need to Validate Target between Microbatches ### What changes were proposed in this pull request? This PR addresses post-merge comments to the Transaction API: https://github.com/apache/spark/pull/55278. The focus is on improving streaming use cases. In particular, for transactional catalogs the streaming target is created as a v2 table reference so we can detect any table changes between micro batches. ### Why are the changes needed? We need to detect any changes of the write target in each micro batch. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added new tests for streaming use cases. ### Was this patch authored or co-authored using generative AI tooling? Claude Sonnet 4.6. Closes #55623 from andreaschat-db/dsv2TransactionApiImprovements. Authored-by: Andreas Chatzistergiou <[email protected]> Signed-off-by: Anton Okolnychyi <[email protected]> (cherry picked from commit 81988964645132fae0857429743ead0bd0702160) --- .../spark/sql/connector/write/BatchWrite.java | 7 +- .../connector/write/streaming/StreamingWrite.java | 8 +- .../spark/sql/catalyst/analysis/Analyzer.scala | 56 +++++---- .../sql/catalyst/analysis/RelationResolution.scala | 6 +- .../sql/catalyst/analysis/V2TableReference.scala | 32 +++-- .../sql/catalyst/plans/logical/v2Commands.scala | 13 +-- .../catalyst/transactions/TransactionUtils.scala | 4 +- .../sql/connector/catalog/LookupCatalog.scala | 7 +- .../AnalyzerExtensionPropagationSuite.scala | 15 +++ .../apache/spark/sql/connector/catalog/txns.scala | 2 +- .../spark/sql/execution/QueryExecution.scala | 20 ++-- .../sql/execution/datasources/v2/V2Writes.scala | 21 ++-- .../streaming/runtime/MicroBatchExecution.scala | 16 ++- .../sources/WriteToMicroBatchDataSource.scala | 25 ++-- .../connector/PathBasedTableTransactionSuite.scala | 73 +++++++++++- .../sql/connector/StreamingTransactionSuite.scala | 130 ++++++++++++++++++--- 16 files changed, 324 insertions(+), 111 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java index 75816349af38..359cb7a354aa 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java @@ -87,9 +87,10 @@ public interface BatchWrite { * passed to this commit method. The remaining commit messages are ignored by Spark. * <p> * Note: this method signals that all data for this write operation has been successfully written. - * It is NOT a transactional commit. When this write is part of a - * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction}, the transaction is - * committed separately via + * When this write is part of a + * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction}, connector + * implementations should stage the written data durably but must not make it visible to readers. + * Changes are propagated and made visible only when the enclosing transaction is committed via * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction#commit()}. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java index 764ed0a35a3f..f4759e675a5c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.write.streaming; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.catalog.transactions.Transaction; import org.apache.spark.sql.connector.write.DataWriter; import org.apache.spark.sql.connector.write.PhysicalWriteInfo; import org.apache.spark.sql.connector.write.WriterCommitMessage; @@ -82,10 +83,9 @@ public interface StreamingWrite { * multiple commits for the same epoch are idempotent. * <p> * Note: this method signals that all data for this write operation has been successfully written. - * It is NOT a transactional commit. When this write is part of a - * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction}, the transaction is - * committed separately via - * {@link org.apache.spark.sql.connector.catalog.transactions.Transaction#commit()}. + * When this write is part of a {@link Transaction}, connector implementations should stage the + * written data durably but must not make it visible to readers. Changes are propagated and made + * visible only when the enclosing transaction is committed via {@link Transaction#commit()}. */ void commit(long epochId, WriterCommitMessage[] messages); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a6137e9a6c8f..d123d36c23b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1071,37 +1071,12 @@ class Analyzer( } } - // Resolve the write target of a V2 write command (batch or streaming). - private def resolveWriteTarget( - write: LogicalPlan, - table: NamedRelation, - withNewTable: NamedRelation => LogicalPlan): LogicalPlan = { - table match { - case u: UnresolvedRelation if !u.isStreaming => - resolveRelation(u).map(unwrapRelationPlan).map { - case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError( - v.desc.identifier, write) - case u: UnresolvedCatalogRelation => - throw QueryCompilationErrors.writeIntoV1TableNotAllowedError( - u.tableMeta.identifier, write) - case r: DataSourceV2Relation => withNewTable(r) - case _ => - throw QueryCompilationErrors.writeIntoTempViewNotAllowedError( - u.multipartIdentifier.quoted) - }.getOrElse(write) - case _ => write - } - } - // Resolve V2TableReference nodes inside temp view plans. These are created by // V2TableReference.createForTempView. We only need to resolve it when returning // the plan of temp views (in resolveViews and unwrapRelationPlan). private def resolveTableReferencesInTempView(plan: LogicalPlan): LogicalPlan = { plan.resolveOperatorsUp { - case r: V2TableReference => - assert(r.context.isInstanceOf[V2TableReference.TemporaryViewContext], - s"""Expected TemporaryViewContext in temp view but got - |${r.context.getClass.getSimpleName}""".stripMargin) + case r: V2TableReference if r.context.isInstanceOf[V2TableReference.TemporaryViewContext] => relationResolution.resolveReference(r) } } @@ -1125,11 +1100,34 @@ class Analyzer( case other => i.copy(table = other) } - case write: StreamingV2WriteCommand => - resolveWriteTarget(write, write.table, write.withNewTable) + case write: V2StreamingWriteCommand => + write.table match { + case ref: V2TableReference => + relationResolution.resolveReference(ref) match { + case r: NamedRelation => write.withNewTable(r) + case other => throw SparkException.internalError( + s"Expected V2TableReference write target to resolve to a NamedRelation, " + + s"but got ${other.getClass.getName}") + } + case _ => write + } case write: V2WriteCommand => - resolveWriteTarget(write, write.table, write.withNewTable) + write.table match { + case u: UnresolvedRelation if !u.isStreaming => + resolveRelation(u).map(unwrapRelationPlan).map { + case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError( + v.desc.identifier, write) + case u: UnresolvedCatalogRelation => + throw QueryCompilationErrors.writeIntoV1TableNotAllowedError( + u.tableMeta.identifier, write) + case r: DataSourceV2Relation => write.withNewTable(r) + case _ => + throw QueryCompilationErrors.writeIntoTempViewNotAllowedError( + u.multipartIdentifier.quoted) + }.getOrElse(write) + case _ => write + } case u: UnresolvedRelation => resolveRelation(u).map(resolveViews(_, u.options)).getOrElse(u) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index 462971eac066..55a7ad10790e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -458,7 +458,11 @@ class RelationResolution( } def resolveReference(ref: V2TableReference): LogicalPlan = { - val relation = getOrLoadRelation(ref) + val relation = if (ref.context.cacheable) { + getOrLoadRelation(ref) + } else { + loadRelation(ref) + } val planId = ref.getTagValue(LogicalPlan.PLAN_ID_TAG) cloneWithPlanId(relation, planId) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index 6baef6f9bed6..223e7012af6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.V2TableReference.Context import org.apache.spark.sql.catalyst.analysis.V2TableReference.TableInfo import org.apache.spark.sql.catalyst.analysis.V2TableReference.TemporaryViewContext import org.apache.spark.sql.catalyst.analysis.V2TableReference.TransactionContext +import org.apache.spark.sql.catalyst.analysis.V2TableReference.WriteTargetContext import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics @@ -84,11 +85,24 @@ private[sql] object V2TableReference { columns: Seq[Column], metadataColumns: Seq[MetadataColumn]) - sealed trait Context + sealed trait Context { + def cacheable: Boolean + } + /** Context for relations that are re-resolved on access of a dataframe temp view. */ - case class TemporaryViewContext(viewName: Seq[String]) extends Context + case class TemporaryViewContext(viewName: Seq[String]) extends Context { + val cacheable = true + } + /** Context for relations that are re-resolved through a transaction catalog. */ - case object TransactionContext extends Context + case object TransactionContext extends Context { + val cacheable = true + } + + /** Context for write targets. */ + case object WriteTargetContext extends Context { + val cacheable = false + } def createForTempView(relation: DataSourceV2Relation, viewName: Seq[String]): V2TableReference = { create(relation, TemporaryViewContext(viewName)) @@ -100,13 +114,17 @@ private[sql] object V2TableReference { create(relation, TransactionContext) } + def createForWriteTarget(relation: DataSourceV2Relation): V2TableReference = { + create(relation, WriteTargetContext) + } + private def create(relation: DataSourceV2Relation, context: Context): V2TableReference = { val ref = V2TableReference( relation.catalog.get.asTableCatalog, relation.identifier.get, relation.options, TableInfo( - tableId = Option(relation.table.id()), + tableId = Option(relation.table.id), columns = relation.table.columns.toImmutableArraySeq, metadataColumns = V2TableUtil.extractMetadataColumns(relation)), relation.output, @@ -122,14 +140,14 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { ref.context match { case ctx: TemporaryViewContext => validateLoadedTableInTempView(table, ref, ctx) - case TransactionContext => - validateLoadedTableInTransaction(table, ref) + case TransactionContext | WriteTargetContext => + validateNoChanges(table, ref) case ctx => throw SparkException.internalError(s"Unknown table ref context: ${ctx.getClass.getName}") } } - private def validateLoadedTableInTransaction(table: Table, ref: V2TableReference): Unit = { + private def validateNoChanges(table: Table, ref: V2TableReference): Unit = { // Make sure the table was not dropped and recreated. ref.info.tableId.foreach(V2TableUtil.validateTableId(ref.name, _, table)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 5dd2f10c89cb..40cf5009b97d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -154,6 +154,12 @@ trait V2WriteCommand def withNewTable(newTable: NamedRelation): V2WriteCommand } +/** Trait for streaming write commands that participate in DSv2 transactions. */ +trait V2StreamingWriteCommand extends TransactionalWrite { + override def table: NamedRelation + def withNewTable(newTable: NamedRelation): V2StreamingWriteCommand +} + trait V2PartitionCommand extends UnaryCommand { def table: LogicalPlan def allowPartialPartitionSpec: Boolean = false @@ -1085,7 +1091,6 @@ case class MergeIntoTable( with SupportsSubquery with TransactionalWrite { - // Implements WriteWithSchemaEvolution.table and TransactionalWrite.table. override val table: LogicalPlan = EliminateSubqueryAliases(targetTable) override def withNewTable(newTable: NamedRelation): MergeIntoTable = { @@ -1355,12 +1360,6 @@ trait TransactionalWrite extends LogicalPlan { def table: LogicalPlan } -/** Trait for streaming write commands that participate in DSv2 transactions. */ -trait StreamingV2WriteCommand extends TransactionalWrite { - override def table: NamedRelation - def withNewTable(newTable: NamedRelation): StreamingV2WriteCommand -} - /** * The logical plan of the DROP TABLE command. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala index a5f8afddf01c..b59733df0d34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/transactions/TransactionUtils.scala @@ -47,8 +47,8 @@ object TransactionUtils { if (txn.catalog.name != catalog.name) { abort(txn) throw SparkException.internalError( - s"""Transaction catalog name (${txn.catalog.name}) - |must match original catalog name (${catalog.name}).""".stripMargin) + s"Transaction catalog name (${txn.catalog.name}) " + + s"must match original catalog name (${catalog.name}).") } txn } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala index dd5be45bfc5f..14c066373032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connector.catalog import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedIdentifier, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedIdentifier, UnresolvedRelation, V2TableReference} import org.apache.spark.sql.catalyst.plans.logical.TransactionalWrite import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -173,6 +173,11 @@ private[sql] trait LookupCatalog extends Logging { Some(c) case UnresolvedIdentifier(CatalogAndIdentifier(c: TransactionalCatalogPlugin, _), _) => Some(c) + case ref: V2TableReference => + ref.catalog match { + case c: TransactionalCatalogPlugin => Some(c) + case _ => None + } case _ => None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala index 02cfe6b4eb7e..65ab822ec841 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalyzerExtensionPropagationSuite.scala @@ -54,6 +54,21 @@ class AnalyzerExtensionPropagationSuite extends SparkFunSuite { new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry)) test("withCatalogManager propagates all extension points") { + // Counts every declared field on Analyzer (backing fields for vals, + // constructor params, and fields inherited from mixed-in traits). When this assertion fails, + // a field was added to or removed from Analyzer. If the change is a new extension point, + // add it to Analyzer.withCatalogManager, add an assertion in the clone checks below, + // and update EXPECTED_FIELD_COUNT. + val EXPECTED_FIELD_COUNT = 12 + val analyzerFields = classOf[Analyzer].getDeclaredFields + .filterNot(f => f.isSynthetic || f.getName.contains("$")) + assert(analyzerFields.length == EXPECTED_FIELD_COUNT, + s"Analyzer has ${analyzerFields.length} declared fields " + + s"(${analyzerFields.map(_.getName).sorted.mkString(", ")}), " + + s"but expected $EXPECTED_FIELD_COUNT. " + + s"If a new extension point was added, register it in Analyzer.withCatalogManager, " + + s"add an assertion in this test, and update EXPECTED_FIELD_COUNT.") + val analyzer = new Analyzer(newCatalogManager()) { override val hintResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule) override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq(dummyRule) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 3adf38f6a218..4b9dff5c3d78 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -177,8 +177,8 @@ class TxnTableCatalog(delegate: InMemoryRowLevelOperationTableCatalog) extends T throw new IllegalArgumentException(s"Cannot drop all fields") } - // TODO: We need to pass all tracked predicates to the new TXN table. val newTxnTable = new TxnTable(txnTable.delegate, schema, this) + newTxnTable.scanEvents ++= txnTable.scanEvents tables.put(ident, newTxnTable) newTxnTable } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 5bfb73e6b887..65bc57de907b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -114,7 +114,7 @@ class QueryExecution( // should keep state about the reads (tables+predicates) that occurred during the transaction. // 3. The analyzer instance is passed to nested Query Execution instances. These need to respect // the open transaction instead of creating their own. - private lazy val transactionOpt: Option[Transaction] = + private val lazyTransactionOpt = LazyTry { // Always inherit an active transaction from the outer analyzer, regardless of mode. analyzerOpt.flatMap(_.catalogManager.transaction).orElse { // Only begin a new transaction for outer QEs that lead to execution. @@ -136,6 +136,8 @@ class QueryExecution( None } } + } + private def transactionOpt: Option[Transaction] = lazyTransactionOpt.get // For path-based tables (e.g. `format.`/path/to/table``) the first identifier part is a // connector name. SupportsCatalogOptions on the connector tells us which catalog actually @@ -169,14 +171,18 @@ class QueryExecution( // so that all catalog lookups and rule applications during analysis see the correct state // without relying on thread-local context. Any nested QueryExecution that is created during // analysis or execution of a transactional plan must receive this analyzer via analyzerOpt. - private lazy val analyzer: Analyzer = analyzerOpt.getOrElse { - transactionOpt match { - case Some(txn) => - sparkSession.sessionState.analyzer.withCatalogManager(catalogManager.withTransaction(txn)) - case None => - sparkSession.sessionState.analyzer + private val lazyAnalyzer = LazyTry { + analyzerOpt.getOrElse { + transactionOpt match { + case Some(txn) => + sparkSession.sessionState.analyzer.withCatalogManager( + catalogManager.withTransaction(txn)) + case None => + sparkSession.sessionState.analyzer + } } } + private def analyzer: Analyzer = lazyAnalyzer.get def assertAnalyzed(): Unit = { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index 0cbf260457ff..be8e96e8034d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertOnlyMerge, import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.catalyst.util.WriteDeltaProjections -import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} +import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.write.{DeltaWriteBuilder, LogicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwriteV2, SupportsTruncate, Write, WriteBuilder} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -98,18 +98,15 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { o.copy(write = Some(write), query = newQuery) case WriteToMicroBatchDataSource( - relation, query, queryId, options, outputMode, Some(batchId)) => - val v2Relation = relation.asInstanceOf[DataSourceV2Relation] - val writeOptions = mergeOptions(options, v2Relation.options.asCaseSensitiveMap.asScala.toMap) - // Guaranteed to support writes since it is a strict requirement to construct - // WriteToMicroBatchDataSource. - val writeTable = v2Relation.table.asInstanceOf[SupportsWrite] - val writeBuilder = newWriteBuilder(writeTable, writeOptions, query.schema, queryId = queryId) - val write = buildWriteForMicroBatch(writeTable, writeBuilder, outputMode) + r: DataSourceV2Relation, query, queryId, options, outputMode, Some(batchId)) => + val table = r.table + val writeOptions = mergeOptions(options, r.options.asCaseSensitiveMap.asScala.toMap) + val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId = queryId) + val write = buildWriteForMicroBatch(table, writeBuilder, outputMode) val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming) val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq - val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, v2Relation.funCatalog) - WriteToDataSourceV2(Some(v2Relation), microBatchWrite, newQuery, customMetrics) + val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog) + WriteToDataSourceV2(Some(r), microBatchWrite, newQuery, customMetrics) case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, projections, _, None) => val rowSchema = projections.rowProjection.schema @@ -139,7 +136,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { } private def buildWriteForMicroBatch( - table: SupportsWrite, + table: Table, writeBuilder: WriteBuilder, outputMode: OutputMode): Write = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala index d4fae034e797..b499e676a84a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkIllegalStateException} import org.apache.spark.internal.LogKeys import org.apache.spark.internal.LogKeys._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.analysis.V2TableReference import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Deduplicate, DeduplicateWithinWatermark, Distinct, FlatMapGroupsInPandasWithState, FlatMapGroupsWithState, GlobalLimit, Join, LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan, TransformWithState, TransformWithStateInPySpark} @@ -365,11 +365,17 @@ class MicroBatchExecution( sink match { case s: SupportsWrite => val relation = plan.catalogAndIdent match { - // When the catalog is transactional, instead of eagerly creating the relation, we - // delegate resolution to ResolveRelations. This allows to resolve the relation against - // a transactional catalo which keeps track of all tables loaded within the transaction. + // For transactional catalog sinks, capture the baseline table metadata in a + // V2TableReference so that each micro-batch re-resolves the table through the + // transaction-aware catalog and fails if the table has been replaced or schema changed. case Some((catalog: TransactionalCatalogPlugin, ident)) => - UnresolvedRelation(catalog.name +: ident.namespace().toSeq :+ ident.name()) + // Re-resolve through the streaming session's catalog manager so the reference + // captures the streaming-session-specific catalog instance. TransactionalWrite + // detection and transaction begin must happen in the streaming session context. + val catalogManager = sparkSessionForStream.sessionState.catalogManager + val streamingCatalog = catalogManager.catalog(catalog.name) + val v2Relation = DataSourceV2Relation.create(s, Some(streamingCatalog), Some(ident)) + V2TableReference.createForWriteTarget(v2Relation) case Some((catalog, ident)) => DataSourceV2Relation.create(s, Some(catalog), Some(ident)) case None => DataSourceV2Relation.create(s, None, None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala index 7aa7a31bb085..5f8c53df08d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala @@ -19,24 +19,25 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.catalyst.analysis.NamedRelation import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, StreamingV2WriteCommand, UnaryNode} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode, V2StreamingWriteCommand} import org.apache.spark.sql.streaming.OutputMode /** * The logical plan for writing data to a micro-batch stream. * - * Note that this logical plan does not have a corresponding physical plan, as it will be converted - * to [[org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 WriteToDataSourceV2]] + * Note that this logical plan does not have a corresponding physical plan, as it will be + * converted to + * [[org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 WriteToDataSourceV2]] * with [[MicroBatchWrite]] before execution. * - * [[relation]] starts as [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation]] when the - * sink has a catalog+identifier (transactional catalogs), or as a resolved - * [[org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation]] for non-transactional - * catalog-backed sinks and format-based sinks. - * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveRelations]] - * resolves it to [[org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation]] during - * each micro-batch analysis, going through the transaction-aware catalog when a transaction is - * active. + * When the write target is backed by a transactional catalog, it is created as a + * [[org.apache.spark.sql.catalyst.analysis.V2TableReference V2TableReference]]. + * This is then resolved by ResolveRelations as a + * [[org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation DataSourceV2Relation]] + * for each micro-batch. + * + * For non-transactional catalogs, the write target is pre-resolved as a + * [[org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation DataSourceV2Relation]]. */ case class WriteToMicroBatchDataSource( relation: NamedRelation, @@ -45,7 +46,7 @@ case class WriteToMicroBatchDataSource( writeOptions: Map[String, String], outputMode: OutputMode, batchId: Option[Long] = None) - extends UnaryNode with StreamingV2WriteCommand { + extends UnaryNode with V2StreamingWriteCommand { override def child: LogicalPlan = query override def output: Seq[Attribute] = Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala index c6b2f33c25fe..c81f53673af3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/PathBasedTableTransactionSuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.Row -import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTableCatalog, InMemoryTableCatalog, SessionConfigSupport, SupportsCatalogOptions} +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, Identifier, InMemoryRowLevelOperationTableCatalog, InMemoryTableCatalog, SessionConfigSupport, SharedTablesInMemoryRowLevelOperationTableCatalog, SupportsCatalogOptions} +import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryWrapper} import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.sql.util.CaseInsensitiveStringMap /** @@ -32,6 +34,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap */ class PathBasedTableTransactionSuite extends RowLevelOperationSuiteBase { + import testImplicits._ + private val tablePath = "`/path/to/t`" private val tablePathWithFormat = "pathformat.`/path/to/t`" @@ -39,10 +43,11 @@ class PathBasedTableTransactionSuite extends RowLevelOperationSuiteBase { super.beforeEach() spark.conf.set( V2_SESSION_CATALOG_IMPLEMENTATION.key, - classOf[InMemoryRowLevelOperationTableCatalog].getName) + classOf[SharedTablesInMemoryRowLevelOperationTableCatalog].getName) } override def afterEach(): Unit = { + SharedTablesInMemoryRowLevelOperationTableCatalog.reset() spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) super.afterEach() } @@ -52,6 +57,12 @@ class PathBasedTableTransactionSuite extends RowLevelOperationSuiteBase { .asInstanceOf[InMemoryRowLevelOperationTableCatalog] } + private def streamSessionCatalog(query: StreamingQuery): InMemoryRowLevelOperationTableCatalog = { + val session = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sparkSessionForStream + session.sessionState.catalogManager.v2SessionCatalog + .asInstanceOf[InMemoryRowLevelOperationTableCatalog] + } + private def createPathTable(name: String): Unit = { sql(s"CREATE TABLE $name (id INT, data STRING)") } @@ -116,6 +127,64 @@ class PathBasedTableTransactionSuite extends RowLevelOperationSuiteBase { } } + test("streaming write to path-based table participates in transaction") { + sql(s"CREATE TABLE $tablePathWithFormat (value INT)") + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tablePathWithFormat) + + inputData.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + val streamCat = streamSessionCatalog(query) + val txn = streamCat.lastTransaction + assert(txn != null, "expected a transaction to have been committed") + assert(txn.currentState === Committed) + assert(txn.isClosed) + // Streaming must not add transactions to the main session catalog. + assert(catalog.observedTransactions.isEmpty) + checkAnswer(spark.table(tablePathWithFormat), Row(1) :: Row(2) :: Row(3) :: Nil) + } + } + + test("streaming self-join on path-based table is tracked as a scan event") { + sql(s"CREATE TABLE $tablePathWithFormat (value INT)") + sql(s"INSERT INTO $tablePathWithFormat VALUES (1), (2), (3)") + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val staticData = spark.read.table(tablePathWithFormat) + + val query = inputData.toDF() + .join(staticData, "value") + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tablePathWithFormat) + + inputData.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + val streamCat = streamSessionCatalog(query) + val txn = streamCat.lastTransaction + assert(txn != null, "expected a transaction to have been committed") + assert(txn.currentState === Committed) + assert(txn.isClosed) + // The path-based table is both the write target and a batch source in the same transaction. + assert(txn.catalog.txnTables.size === 1) + val txnTable = txn.catalog.txnTables.values.head + assert(txnTable.scanEvents.size === 1) + // Streaming must not add transactions to the main session catalog beyond the pre-existing + // INSERT transaction. + assert(catalog.observedTransactions.size === 1) + } + } + test("SQL insert with unregistered format produces analysis error and aborts transaction") { createPathTable(tablePathWithFormat) // "Unregistered" is not a known catalog and not registered data source. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala index 13b6267a28ff..d356197fa53c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/StreamingTransactionSuite.scala @@ -130,15 +130,23 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { } } - test("batch read from catalog-backed table inside streaming query is tracked as a scan event") { + for (isSelfScan <- Seq(true, false)) + test("batch read from catalog-backed table inside streaming query is tracked as a " + + s"scan event (isSelfScan=$isSelfScan)") { // Target table for the stream. createSimpleTable("value INT") - // Catalog-backed static table used as a batch (non-streaming) source. - val sourceIdent = Identifier.of(namespace, "source_table") - val srcColumns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT")) - catalog.createTable(sourceIdent, new TableInfo.Builder().withColumns(srcColumns).build()) - sql(s"INSERT INTO $sourceNameAsString VALUES (1), (2), (3)") + // Pick the static (non-streaming) source table. When isSelfScan is true, the stream's + // write target is also used as the batch source. + val staticSourceName = if (isSelfScan) { + tableNameAsString + } else { + val sourceIdent = Identifier.of(namespace, "source_table") + val srcColumns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT")) + catalog.createTable(sourceIdent, new TableInfo.Builder().withColumns(srcColumns).build()) + sourceNameAsString + } + sql(s"INSERT INTO $staticSourceName VALUES (1), (2), (3)") // The INSERT above runs a transaction on the main session catalog; capture the count now // so we can assert the streaming query does not add more. val mainTxnsBefore = catalog.observedTransactions.size @@ -149,7 +157,7 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { // spark.read produces a DataSourceV2Relation (batch), not a streaming source. // UnresolveRelationsInTransaction converts it to V2TableReference each micro-batch so // the transaction-aware catalog can record the scan event. - val staticData = spark.read.table(sourceNameAsString) + val staticData = spark.read.table(staticSourceName) val query = inputData.toDF() .join(staticData, "value") @@ -157,6 +165,9 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { .option("checkpointLocation", checkpointDir.getAbsolutePath) .toTable(tableNameAsString) + // There should be no transaction yet in the cloned session. + assert(streamCatalog(query).lastTransaction === null) + inputData.addData(1, 2, 3) query.processAllAvailable() query.stop() @@ -166,26 +177,109 @@ class StreamingTransactionSuite extends RowLevelOperationSuiteBase { assert(txn.currentState === Committed) assert(txn.isClosed) - // Both the write target and the batch source participate in the transaction. - assert(txn.catalog.txnTables.size === 2) + if (isSelfScan) { + // Target acts as both write target and batch source. + assert(txn.catalog.txnTables.size === 1) + val targetTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(tableNameAsString) + assert(targetTxnTable.scanEvents.size === 1) + } else { + // Both the write target and the batch source participate in the transaction. + assert(txn.catalog.txnTables.size === 2) + val targetTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(tableNameAsString) + assert(targetTxnTable.scanEvents.isEmpty) + // The static source was read exactly once and its scan event was captured. + val sourceTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(sourceNameAsString) + assert(sourceTxnTable.scanEvents.size === 1) + } - val targetTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(tableNameAsString) - assert(targetTxnTable.scanEvents.isEmpty) + // Streaming must not add transactions to the main session catalog beyond pre-existing + // setup transactions. + assert(catalog.observedTransactions.size === mainTxnsBefore) - // The static source was read exactly once and its scan event was captured. - val sourceTxnTable = indexByName(txn.catalog.txnTables.values.toSeq)(sourceNameAsString) - assert(sourceTxnTable.scanEvents.size === 1) + // In the self-scan case the target was pre-populated with 1,2,3 and the streaming append + // adds another 1,2,3 from the join, so the table ends with two copies of each value. + val expectedRows = if (isSelfScan) { + Seq(Row(1), Row(2), Row(3), Row(1), Row(2), Row(3)) + } else { + Seq(Row(1), Row(2), Row(3)) + } + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), expectedRows) + } + } - // Streaming must not add transactions to the main session catalog beyond the pre-existing - // INSERT transaction. - assert(catalog.observedTransactions.size === mainTxnsBefore) + test("micro-batch fails when target table schema changes between batches") { + createSimpleTable("value INT") + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val query = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tableNameAsString) + + // Batch 1 succeeds against the original schema captured at query start. + inputData.addData(1, 2, 3) + query.processAllAvailable() + + val firstTxn = streamCatalog(query).lastTransaction + assert(firstTxn != null) + assert(firstTxn.currentState === Committed) + + // Mutate the target schema between micro-batches via the main session catalog. The + // shared in-memory backing store makes the change visible to the streaming session. + sql(s"ALTER TABLE $tableNameAsString ADD COLUMNS (extra STRING)") + + // Batch 2: re-resolution of the WriteTargetContext reference loads the altered table + // and validateNoChanges rejects the added column. + inputData.addData(4, 5, 6) + val ex = intercept[Exception] { query.processAllAvailable() } + assert(ex.getMessage.contains("INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS") || + Option(ex.getCause).exists( + _.getMessage.contains("INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS"))) + query.stop() + // Only batch 1's rows should be visible; batch 2 never wrote anything. checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), + sql(s"SELECT value FROM $tableNameAsString"), Seq(Row(1), Row(2), Row(3))) } } + test("micro-batch fails when batch source schema changes after capture") { + createSimpleTable("value INT") + + val sourceIdent = Identifier.of(namespace, "source_table") + val srcColumns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT")) + catalog.createTable(sourceIdent, new TableInfo.Builder().withColumns(srcColumns).build()) + sql(s"INSERT INTO $sourceNameAsString VALUES (1), (2), (3)") + + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + + // Capture the static source against its original schema. + val staticData = spark.read.table(sourceNameAsString) + + // Mutate the source schema after the static reference was captured. + sql(s"ALTER TABLE $sourceNameAsString ADD COLUMNS (extra STRING)") + + val query = inputData.toDF() + .join(staticData, "value") + .writeStream + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .toTable(tableNameAsString) + + inputData.addData(1, 2, 3) + val ex = intercept[Exception] { query.processAllAvailable() } + assert(ex.getMessage.contains("INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS") || + Option(ex.getCause).exists( + _.getMessage.contains("INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS"))) + query.stop() + + // Analysis failed before any commit. The target must remain empty. + checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), Seq.empty) + } + } + test("transaction is aborted when micro-batch write fails and no data is written") { val columns = CatalogV2Util.structTypeToV2Columns(StructType.fromDDL("value INT")) val tableInfo = new TableInfo.Builder() --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
