This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ecbb1bff407 Added insert and update on conflict to renderedtifields.py (#63874)
ecbb1bff407 is described below

commit ecbb1bff407f318218224495cd52a020013e183f
Author: manipatnam <[email protected]>
AuthorDate: Wed Jun 3 06:11:50 2026 +0530

    Added insert and update on conflict to renderedtifields.py (#63874)
    
    closes: #61705
---
 .../src/airflow/models/renderedtifields.py         | 31 +++++++++--
 airflow-core/src/airflow/models/variable.py        | 34 +-----------
 airflow-core/src/airflow/utils/sqlalchemy.py       | 50 ++++++++++++++++-
 .../tests/unit/models/test_renderedtifields.py     | 62 +++++++++++++++++++++-
 4 files changed, 140 insertions(+), 37 deletions(-)

diff --git a/airflow-core/src/airflow/models/renderedtifields.py b/airflow-core/src/airflow/models/renderedtifields.py
index d9b5f115b33..e405f3bfce7 100644
--- a/airflow-core/src/airflow/models/renderedtifields.py
+++ b/airflow-core/src/airflow/models/renderedtifields.py
@@ -38,7 +38,7 @@ from airflow.models.base import StringID, TaskInstanceDependencies
 from airflow.serialization.helpers import serialize_template_field
 from airflow.utils.retries import retry_db_transaction
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import get_dialect_name
+from airflow.utils.sqlalchemy import build_upsert_stmt, get_dialect_name
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -239,13 +239,38 @@ class RenderedTaskInstanceFields(TaskInstanceDependencies):
 
     @provide_session
     @retry_db_transaction
-    def write(self, session: Session):
+    def write(self, session: Session = NEW_SESSION):
         """
         Write instance to database.
 
+        Uses a database-level upsert (INSERT ... ON CONFLICT DO UPDATE) to
+        atomically insert or update the record, avoiding race conditions that
+        can occur with session.merge() when concurrent requests (e.g. from
+        client-side timeout retries) target the same primary key.
+
         :param session: SqlAlchemy Session
         """
-        session.merge(self)
+        values = {
+            "dag_id": self.dag_id,
+            "task_id": self.task_id,
+            "run_id": self.run_id,
+            "map_index": self.map_index,
+            "rendered_fields": self.rendered_fields,
+            "k8s_pod_yaml": self.k8s_pod_yaml,
+        }
+        update_on_conflict = {
+            "rendered_fields": self.rendered_fields,
+            "k8s_pod_yaml": self.k8s_pod_yaml,
+        }
+
+        stmt = build_upsert_stmt(
+            get_dialect_name(session),
+            RenderedTaskInstanceFields,
+            ["dag_id", "task_id", "run_id", "map_index"],
+            values,
+            update_on_conflict,
+        )
+        session.execute(stmt)
 
     @classmethod
     @provide_session
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index 667c3303567..eb50b92f988 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -49,44 +49,14 @@ except ImportError:
 from airflow.secrets.metastore import MetastoreBackend
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import get_dialect_name
+from airflow.utils.sqlalchemy import build_upsert_stmt, get_dialect_name
 
 if TYPE_CHECKING:
-    from sqlalchemy.dialects.mysql.dml import Insert as MySQLInsert
-    from sqlalchemy.dialects.postgresql.dml import Insert as PostgreSQLInsert
-    from sqlalchemy.dialects.sqlite.dml import Insert as SQLiteInsert
     from sqlalchemy.orm import Session
 
 log = logging.getLogger(__name__)
 
 
-def _build_variable_upsert_stmt(
-    dialect: str | None,
-    model: type[Variable],
-    conflict_cols: list[str],
-    values: dict[str, Any],
-    update_fields: dict[str, Any],
-) -> MySQLInsert | PostgreSQLInsert | SQLiteInsert:
-    """Return a dialect-specific INSERT ... ON CONFLICT UPDATE statement."""
-    stmt: MySQLInsert | PostgreSQLInsert | SQLiteInsert
-    if dialect == "postgresql":
-        from sqlalchemy.dialects.postgresql import insert as pg_insert
-
-        stmt = pg_insert(model).values(**values)
-        stmt = stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_fields)
-    elif dialect == "mysql":
-        from sqlalchemy.dialects.mysql import insert as mysql_insert
-
-        stmt = mysql_insert(model).values(**values)
-        stmt = stmt.on_duplicate_key_update(**update_fields)
-    else:
-        from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-
-        stmt = sqlite_insert(model).values(**values)
-        stmt = stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_fields)
-    return stmt
-
-
 class Variable(Base, LoggingMixin):
     """A generic way to store and retrieve arbitrary content or settings as a simple key/value store."""
 
@@ -311,7 +281,7 @@ class Variable(Base, LoggingMixin):
                 is_encrypted=is_encrypted,
                 team_name=team_name,
             )
-            stmt = _build_variable_upsert_stmt(
+            stmt = build_upsert_stmt(
                 get_dialect_name(session), Variable, ["key"], upsert_values, update_fields
             )
             session.execute(stmt)
diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py
index a767c65fad9..35a6f4ee05b 100644
--- a/airflow-core/src/airflow/utils/sqlalchemy.py
+++ b/airflow-core/src/airflow/utils/sqlalchemy.py
@@ -22,7 +22,7 @@ import copy
 import datetime
 import logging
 from collections.abc import Generator
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from sqlalchemy import TIMESTAMP, PickleType, String, event, nullsfirst
 from sqlalchemy.dialects import mysql
@@ -39,6 +39,9 @@ if TYPE_CHECKING:
     from collections.abc import Iterable
 
     from kubernetes.client.models.v1_pod import V1Pod
+    from sqlalchemy.dialects.mysql.dml import Insert as MySQLInsert
+    from sqlalchemy.dialects.postgresql.dml import Insert as PostgreSQLInsert
+    from sqlalchemy.dialects.sqlite.dml import Insert as SQLiteInsert
     from sqlalchemy.exc import OperationalError
     from sqlalchemy.orm import Session
     from sqlalchemy.sql import Select
@@ -58,6 +61,51 @@ def get_dialect_name(session: Session) -> str | None:
     return getattr(bind.dialect, "name", None)
 
 
+def build_upsert_stmt(
+    dialect: str | None,
+    model: Any,
+    conflict_cols: list[str],
+    values: dict[str, Any],
+    update_fields: dict[str, Any],
+) -> MySQLInsert | PostgreSQLInsert | SQLiteInsert:
+    """
+    Build a dialect-specific ``INSERT ... ON CONFLICT DO UPDATE`` (upsert) statement.
+
+    A single-statement upsert is atomic at the database level, which avoids the
+    race conditions that arise from the non-atomic SELECT-then-INSERT performed by
+    ``session.merge()`` when concurrent transactions target the same primary key.
+
+    :param dialect: dialect name as returned by :func:`get_dialect_name`
+    :param model: the SQLAlchemy model (or table) to insert into
+    :param conflict_cols: columns that make up the conflict target (PostgreSQL/SQLite)
+    :param values: column values to insert
+    :param update_fields: column values to set when a conflicting row already exists
+    :raises ValueError: if the dialect does not support a known upsert syntax
+    """
+    stmt: MySQLInsert | PostgreSQLInsert | SQLiteInsert
+    if dialect == "postgresql":
+        from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+        stmt = pg_insert(model).values(**values)
+        stmt = stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_fields)
+    elif dialect == "mysql":
+        from sqlalchemy.dialects.mysql import insert as mysql_insert
+
+        stmt = mysql_insert(model).values(**values)
+        stmt = stmt.on_duplicate_key_update(**update_fields)
+    elif dialect == "sqlite":
+        from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+
+        stmt = sqlite_insert(model).values(**values)
+        stmt = stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_fields)
+    else:
+        raise ValueError(
+            f"Unsupported database dialect '{dialect}' for upsert. "
+            "Supported dialects are: postgresql, mysql, sqlite."
+        )
+    return stmt
+
+
 class random_db_uuid(FunctionElement):
     """
     Cross-dialect random UUID generation for use in SQL expressions.
diff --git a/airflow-core/tests/unit/models/test_renderedtifields.py b/airflow-core/tests/unit/models/test_renderedtifields.py
index 37e6088494d..f695c46aac2 100644
--- a/airflow-core/tests/unit/models/test_renderedtifields.py
+++ b/airflow-core/tests/unit/models/test_renderedtifields.py
@@ -27,7 +27,7 @@ from unittest import mock
 
 import pendulum
 import pytest
-from sqlalchemy import select
+from sqlalchemy import insert, select
 
 from airflow import settings
 from airflow._shared.template_rendering import truncate_rendered_value
@@ -372,6 +372,66 @@ class TestRenderedTaskInstanceFields:
             {"bash_command": "echo test_val_updated", "env": None, "cwd": None},
         )
 
+    def test_write_upsert_existing_record(self, dag_maker, session):
+        """
+        Verify that write() updates an existing row instead of failing on its primary key.
+
+        A row is seeded via a direct INSERT (bypassing write()) to represent a record
+        already present for this task instance. Calling write() with different values
+        must update that row via the upsert's DO UPDATE branch.
+
+        This exercises the upsert's update path within a single transaction; it does not
+        reproduce the concurrent-transaction race from #61705, which needs two separate
+        uncommitted transactions and cannot be triggered reliably in a unit test. The
+        atomic single-statement upsert is what closes that race in production.
+        """
+        with dag_maker("test_write_upsert", session=session):
+            task = BashOperator(task_id="test", bash_command="echo original")
+        dr = dag_maker.create_dagrun()
+        ti = dr.task_instances[0]
+        ti.task = task
+
+        # Seed the row via a direct INSERT to simulate a row already committed by
+        # the first request. Using write() here would mask whether write() itself
+        # correctly handles conflicts, since merge() also handles existing rows.
+        session.execute(
+            insert(RTIF).values(
+                dag_id=ti.dag_id,
+                task_id=ti.task_id,
+                run_id=ti.run_id,
+                map_index=ti.map_index,
+                rendered_fields={"bash_command": "echo original"},
+                k8s_pod_yaml=None,
+            )
+        )
+        session.flush()
+
+        result = session.scalar(
+            select(RTIF).where(
+                RTIF.dag_id == ti.dag_id,
+                RTIF.task_id == ti.task_id,
+                RTIF.run_id == ti.run_id,
+                RTIF.map_index == ti.map_index,
+            )
+        )
+        assert result.rendered_fields == {"bash_command": "echo original"}
+
+        # write() must not raise IntegrityError even though the row already exists.
+        rtif = RTIF(ti=ti, render_templates=False, rendered_fields={"bash_command": "echo updated"})
+        rtif.write(session=session)
+        session.flush()
+        session.expire_all()
+
+        result = session.scalar(
+            select(RTIF).where(
+                RTIF.dag_id == ti.dag_id,
+                RTIF.task_id == ti.task_id,
+                RTIF.run_id == ti.run_id,
+                RTIF.map_index == ti.map_index,
+            )
+        )
+        assert result.rendered_fields == {"bash_command": "echo updated"}
+
     @mock.patch.dict(os.environ, {"AIRFLOW_VAR_API_KEY": "secret"})
     def test_redact(self, dag_maker):
         with mock.patch("airflow._shared.secrets_masker.redact", autospec=True) as redact:

Reply via email to