jedcunningham commented on code in PR #42913:
URL: https://github.com/apache/airflow/pull/42913#discussion_r1823411590
##########
airflow/models/dag.py:
##########
@@ -2461,7 +2469,7 @@ def create_dagrun(
:param conf: Dict containing configuration/parameters to pass to the
DAG
:param creating_job_id: id of the job creating this DagRun
:param session: database session
- :param dag_hash: Hash of Serialized DAG
+ :param dag_version_id: The DagVersion ID to run with
Review Comment:
```suggestion
:param dag_version_id: The DagVersion ID for this run
```
Minor thing, but since the DAG version is only observed for a run (not pinned), it's probably better
to phrase it this way. The bundle version is where we will attempt to hold this stable.
##########
airflow/models/dag_version.py:
##########
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+import random
+import string
+from typing import TYPE_CHECKING
+
+import uuid6
+from sqlalchemy import Column, ForeignKey, Integer, UniqueConstraint, func,
select
+from sqlalchemy.orm import relationship
+from sqlalchemy_utils import UUIDType
+
+from airflow.models.base import Base, StringID
+from airflow.utils import timezone
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+
+log = logging.getLogger(__name__)
+
+
+def _gen_random_str():
+ return "".join(random.choices(string.ascii_letters + string.digits, k=10))
+
+
+class DagVersion(Base):
+ """Model to track the versions of DAGs in the database."""
+
+ __tablename__ = "dag_version"
+ id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+ version_number = Column(Integer, nullable=False, default=1)
+ version_name = Column(StringID(), default=_gen_random_str, nullable=False)
+ dag_id = Column(StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"),
nullable=False)
+ dag_model = relationship("DagModel", back_populates="dag_versions")
+ dag_code = relationship(
+ "DagCode",
+ back_populates="dag_version",
+ uselist=False,
+ cascade="all, delete, delete-orphan",
+ cascade_backrefs=False,
+ )
+ serialized_dag = relationship(
+ "SerializedDagModel",
+ back_populates="dag_version",
+ uselist=False,
+ cascade="all, delete, delete-orphan",
+ cascade_backrefs=False,
+ )
+ dag_runs = relationship("DagRun", back_populates="dag_version",
cascade="all, delete, delete-orphan")
+ task_instances = relationship("TaskInstance", back_populates="dag_version")
+ created_at = Column(UtcDateTime, default=timezone.utcnow)
+
+ __table_args__ = (
+ UniqueConstraint("dag_id", "version_number",
name="dag_id_version_number_unique_constraint"),
+ )
+
+ def __repr__(self):
+ return f"<DagVersion {self.dag_id} - {self.version_name}>"
Review Comment:
```suggestion
return f"<DagVersion {self.dag_id} {self.version}>"
```
Aren't we missing half of the logical key here? The logical key is (dag_id, version_number), but
this repr only shows `version_name`.
##########
airflow/models/dag_version.py:
##########
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+import random
+import string
+from typing import TYPE_CHECKING
+
+import uuid6
+from sqlalchemy import Column, ForeignKey, Integer, UniqueConstraint, func,
select
+from sqlalchemy.orm import relationship
+from sqlalchemy_utils import UUIDType
+
+from airflow.models.base import Base, StringID
+from airflow.utils import timezone
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+
+log = logging.getLogger(__name__)
+
+
+def _gen_random_str():
+ return "".join(random.choices(string.ascii_letters + string.digits, k=10))
+
+
+class DagVersion(Base):
+ """Model to track the versions of DAGs in the database."""
+
+ __tablename__ = "dag_version"
+ id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+ version_number = Column(Integer, nullable=False, default=1)
+ version_name = Column(StringID(), default=_gen_random_str, nullable=False)
+ dag_id = Column(StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"),
nullable=False)
+ dag_model = relationship("DagModel", back_populates="dag_versions")
+ dag_code = relationship(
+ "DagCode",
+ back_populates="dag_version",
+ uselist=False,
+ cascade="all, delete, delete-orphan",
+ cascade_backrefs=False,
+ )
+ serialized_dag = relationship(
+ "SerializedDagModel",
+ back_populates="dag_version",
+ uselist=False,
+ cascade="all, delete, delete-orphan",
+ cascade_backrefs=False,
+ )
+ dag_runs = relationship("DagRun", back_populates="dag_version",
cascade="all, delete, delete-orphan")
+ task_instances = relationship("TaskInstance", back_populates="dag_version")
+ created_at = Column(UtcDateTime, default=timezone.utcnow)
+
+ __table_args__ = (
+ UniqueConstraint("dag_id", "version_number",
name="dag_id_version_number_unique_constraint"),
+ )
+
+ def __repr__(self):
+ return f"<DagVersion {self.dag_id} - {self.version_name}>"
+
+ @classmethod
+ @provide_session
+ def write_dag(
+ cls,
+ *,
+ dag_id: str,
+ version_name: str | None = None,
+ version_number: int = 1,
+ session: Session = NEW_SESSION,
+ ):
+ """Write a new DagVersion into database."""
+ existing_dag_version = session.scalar(
+ with_row_locks(cls._latest_version_select(dag_id), of=DagVersion,
session=session, nowait=True)
+ )
+ if existing_dag_version:
+ version_number = existing_dag_version.version_number + 1
+ if existing_dag_version and not version_name:
+ version_name = existing_dag_version.version_name
+
+ dag_version = DagVersion(
+ dag_id=dag_id,
+ version_number=version_number,
+ version_name=version_name,
+ )
+ log.debug("Writing DagVersion %s to the DB", dag_version)
+ session.add(dag_version)
+ # Flush is necessary here due to the unique constraint and other
linked tables
+ session.flush()
+ log.debug("DagVersion %s written to the DB", dag_version)
+ return dag_version
+
+ @classmethod
+ def _latest_version_select(cls, dag_id: str):
+ return select(cls).where(cls.dag_id ==
dag_id).order_by(cls.created_at.desc()).limit(1)
+
+ @classmethod
+ @provide_session
+ def get_latest_version(cls, dag_id: str, session: Session = NEW_SESSION):
+ return session.scalar(cls._latest_version_select(dag_id))
+
+ @classmethod
+ @provide_session
+ def get_version(
+ cls,
+ dag_id: str,
+ version_name: str | None = None,
Review Comment:
I'm slightly torn on this. `version_name` is not necessarily unique, and `session.scalar` just
returns the "first" matching row - it may not be the version you are really after. Feels
like this should mirror the logical key (dag_id/version_number).
##########
providers/tests/fab/auth_manager/api_endpoints/test_backfill_endpoint.py:
##########
@@ -188,7 +188,7 @@ class TestCreateBackfill(TestBackfillEndpoint):
def test_create_backfill(self, session, dag_maker):
with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * *
*") as dag:
EmptyOperator(task_id="mytask")
- session.add(SerializedDagModel(dag))
+ SerializedDagModel.write_dag(dag)
Review Comment:
```suggestion
SerializedDagModel.write_dag(dag, session=session)
```
Probably should use the same session?
##########
tests/jobs/test_scheduler_job.py:
##########
@@ -3384,57 +3381,58 @@ def test_verify_integrity_if_dag_changed(self,
dag_maker):
assert tis_count == 2
latest_dag_version =
SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session)
- assert dr.dag_hash == latest_dag_version
+ assert dr.dag_version.serialized_dag.dag_hash == latest_dag_version
session.rollback()
session.close()
- def test_verify_integrity_if_dag_disappeared(self, dag_maker, caplog):
- # CleanUp
- with create_session() as session:
- session.query(SerializedDagModel).filter(
- SerializedDagModel.dag_id ==
"test_verify_integrity_if_dag_disappeared"
- ).delete(synchronize_session=False)
-
- with dag_maker(dag_id="test_verify_integrity_if_dag_disappeared") as
dag:
- BashOperator(task_id="dummy", bash_command="echo hi")
-
- scheduler_job = Job()
- self.job_runner = SchedulerJobRunner(job=scheduler_job,
subdir=os.devnull)
-
- session = settings.Session()
- orm_dag = dag_maker.dag_model
- assert orm_dag is not None
-
- scheduler_job = Job()
- self.job_runner = SchedulerJobRunner(job=scheduler_job,
subdir=os.devnull)
-
- self.job_runner.processor_agent = mock.MagicMock()
- dag =
self.job_runner.dagbag.get_dag("test_verify_integrity_if_dag_disappeared",
session=session)
- self.job_runner._create_dag_runs([orm_dag], session)
- dag_id = dag.dag_id
- drs = DagRun.find(dag_id=dag_id, session=session)
- assert len(drs) == 1
- dr = drs[0]
-
- dag_version_1 = SerializedDagModel.get_latest_version_hash(dag_id,
session=session)
- assert dr.dag_hash == dag_version_1
- assert self.job_runner.dagbag.dags ==
{"test_verify_integrity_if_dag_disappeared": dag}
- assert
len(self.job_runner.dagbag.dags.get("test_verify_integrity_if_dag_disappeared").tasks)
== 1
-
- SerializedDagModel.remove_dag(dag_id=dag_id)
- dag = self.job_runner.dagbag.dags[dag_id]
- self.job_runner.dagbag.dags = MagicMock()
- self.job_runner.dagbag.dags.get.side_effect = [dag, None]
- session.flush()
- with caplog.at_level(logging.WARNING):
- callback = self.job_runner._schedule_dag_run(dr, session)
- assert "The DAG disappeared before verifying integrity" in
caplog.text
-
- assert callback is None
-
- session.rollback()
- session.close()
+ # def test_verify_integrity_if_dag_disappeared(self, dag_maker, caplog):
Review Comment:
We should either fix or remove this test rather than leave it commented out.
##########
airflow/models/taskinstance.py:
##########
@@ -1277,6 +1280,8 @@ def _refresh_from_task(
:meta private:
"""
task_instance.task = task
+ print(task_instance)
+ print(task)
Review Comment:
```suggestion
```
##########
airflow/models/dag_version.py:
##########
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+import random
+import string
+from typing import TYPE_CHECKING
+
+import uuid6
+from sqlalchemy import Column, ForeignKey, Integer, UniqueConstraint, func,
select
+from sqlalchemy.orm import relationship
+from sqlalchemy_utils import UUIDType
+
+from airflow.models.base import Base, StringID
+from airflow.utils import timezone
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+
+log = logging.getLogger(__name__)
+
+
+def _gen_random_str():
+ return "".join(random.choices(string.ascii_letters + string.digits, k=10))
+
+
+class DagVersion(Base):
+ """Model to track the versions of DAGs in the database."""
+
+ __tablename__ = "dag_version"
+ id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+ version_number = Column(Integer, nullable=False, default=1)
+ version_name = Column(StringID(), default=_gen_random_str, nullable=False)
+ dag_id = Column(StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"),
nullable=False)
+ dag_model = relationship("DagModel", back_populates="dag_versions")
+ dag_code = relationship(
+ "DagCode",
+ back_populates="dag_version",
+ uselist=False,
+ cascade="all, delete, delete-orphan",
+ cascade_backrefs=False,
+ )
+ serialized_dag = relationship(
+ "SerializedDagModel",
+ back_populates="dag_version",
+ uselist=False,
+ cascade="all, delete, delete-orphan",
+ cascade_backrefs=False,
+ )
+ dag_runs = relationship("DagRun", back_populates="dag_version",
cascade="all, delete, delete-orphan")
+ task_instances = relationship("TaskInstance", back_populates="dag_version")
+ created_at = Column(UtcDateTime, default=timezone.utcnow)
+
+ __table_args__ = (
+ UniqueConstraint("dag_id", "version_number",
name="dag_id_version_number_unique_constraint"),
+ )
+
+ def __repr__(self):
+ return f"<DagVersion {self.dag_id} - {self.version_name}>"
+
+ @classmethod
+ @provide_session
+ def write_dag(
+ cls,
+ *,
+ dag_id: str,
+ version_name: str | None = None,
+ version_number: int = 1,
+ session: Session = NEW_SESSION,
+ ):
+ """Write a new DagVersion into database."""
+ existing_dag_version = session.scalar(
+ with_row_locks(cls._latest_version_select(dag_id), of=DagVersion,
session=session, nowait=True)
+ )
+ if existing_dag_version:
+ version_number = existing_dag_version.version_number + 1
+ if existing_dag_version and not version_name:
+ version_name = existing_dag_version.version_name
Review Comment:
Hmm, why wouldn't we want to honor folks removing the `version_name`? As written, a new version
written without a name silently inherits the previous version's name.
##########
airflow/models/dag_version.py:
##########
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+import random
+import string
+from typing import TYPE_CHECKING
+
+import uuid6
+from sqlalchemy import Column, ForeignKey, Integer, UniqueConstraint, func,
select
+from sqlalchemy.orm import relationship
+from sqlalchemy_utils import UUIDType
+
+from airflow.models.base import Base, StringID
+from airflow.utils import timezone
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+
+log = logging.getLogger(__name__)
+
+
+def _gen_random_str():
+ return "".join(random.choices(string.ascii_letters + string.digits, k=10))
+
+
+class DagVersion(Base):
+ """Model to track the versions of DAGs in the database."""
+
+ __tablename__ = "dag_version"
+ id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+ version_number = Column(Integer, nullable=False, default=1)
+ version_name = Column(StringID(), default=_gen_random_str, nullable=False)
+ dag_id = Column(StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"),
nullable=False)
+ dag_model = relationship("DagModel", back_populates="dag_versions")
+ dag_code = relationship(
+ "DagCode",
+ back_populates="dag_version",
+ uselist=False,
+ cascade="all, delete, delete-orphan",
+ cascade_backrefs=False,
+ )
+ serialized_dag = relationship(
+ "SerializedDagModel",
+ back_populates="dag_version",
+ uselist=False,
+ cascade="all, delete, delete-orphan",
+ cascade_backrefs=False,
+ )
+ dag_runs = relationship("DagRun", back_populates="dag_version",
cascade="all, delete, delete-orphan")
+ task_instances = relationship("TaskInstance", back_populates="dag_version")
+ created_at = Column(UtcDateTime, default=timezone.utcnow)
+
+ __table_args__ = (
+ UniqueConstraint("dag_id", "version_number",
name="dag_id_version_number_unique_constraint"),
+ )
+
+ def __repr__(self):
+ return f"<DagVersion {self.dag_id} - {self.version_name}>"
+
+ @classmethod
+ @provide_session
+ def write_dag(
+ cls,
+ *,
+ dag_id: str,
+ version_name: str | None = None,
+ version_number: int = 1,
+ session: Session = NEW_SESSION,
+ ):
+ """Write a new DagVersion into database."""
+ existing_dag_version = session.scalar(
+ with_row_locks(cls._latest_version_select(dag_id), of=DagVersion,
session=session, nowait=True)
+ )
+ if existing_dag_version:
+ version_number = existing_dag_version.version_number + 1
+ if existing_dag_version and not version_name:
+ version_name = existing_dag_version.version_name
+
+ dag_version = DagVersion(
+ dag_id=dag_id,
+ version_number=version_number,
+ version_name=version_name,
+ )
+ log.debug("Writing DagVersion %s to the DB", dag_version)
+ session.add(dag_version)
+ # Flush is necessary here due to the unique constraint and other
linked tables
+ session.flush()
+ log.debug("DagVersion %s written to the DB", dag_version)
+ return dag_version
+
+ @classmethod
+ def _latest_version_select(cls, dag_id: str):
+ return select(cls).where(cls.dag_id ==
dag_id).order_by(cls.created_at.desc()).limit(1)
+
+ @classmethod
+ @provide_session
+ def get_latest_version(cls, dag_id: str, session: Session = NEW_SESSION):
+ return session.scalar(cls._latest_version_select(dag_id))
+
+ @classmethod
+ @provide_session
+ def get_version(
+ cls,
+ dag_id: str,
+ version_name: str | None = None,
+ version_number: int | None = None,
+ session: Session = NEW_SESSION,
+ ):
+ version_select_obj = select(cls).where(cls.dag_id == dag_id)
+ if version_name:
+ version_select_obj = version_select_obj.where(cls.version_name ==
version_name)
+ if version_number:
+ version_select_obj = version_select_obj.where(cls.version_number
== version_number)
+ version_select_obj =
version_select_obj.order_by(cls.version_number.desc()).limit(1)
+ return session.scalar(version_select_obj)
+
+ @property
+ def version(self):
+ if not self.version_name and not self.version_number:
+ return None
Review Comment:
```suggestion
return str(self.version_number)
```
If the optional `version_name` isn't set, shouldn't we still return
something?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]