This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git

The following commit(s) were added to refs/heads/main by this push:
     new 55e419e95ab Remove AIP-44 from Job (#44493)
55e419e95ab is described below

commit 55e419e95ab027d161cef95571300af9b2c81a0d
Author: Jarek Potiuk <ja...@potiuk.com>
AuthorDate: Sat Nov 30 03:19:32 2024 +0100

    Remove AIP-44 from Job (#44493)

    Part of #44436
---
 airflow/jobs/job.py                                | 128 +++++----------------
 .../providers/edge/worker_api/routes/rpc_api.py    |   1 -
 tests/jobs/test_base_job.py                        |   9 +-
 3 files changed, 32 insertions(+), 106 deletions(-)

diff --git a/airflow/jobs/job.py b/airflow/jobs/job.py
index 6e802372d83..75a075efdfc 100644
--- a/airflow/jobs/job.py
+++ b/airflow/jobs/job.py
@@ -26,13 +26,11 @@ from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import backref, foreign, relationship
 from sqlalchemy.orm.session import make_transient
 
-from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.base import ID_LEN, Base
-from airflow.serialization.pydantic.job import JobPydantic
 from airflow.stats import Stats
 from airflow.traces.tracer import Trace, add_span
 from airflow.utils import timezone
@@ -40,8 +38,7 @@ from airflow.utils.helpers import convert_camel_to_snake
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
 from airflow.utils.platform import getuser
-from airflow.utils.retries import retry_db_transaction
-from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
 from airflow.utils.state import JobState
 
@@ -168,7 +165,10 @@ class Job(Base, LoggingMixin):
         except Exception as e:
             self.log.error("on_kill() method failed: %s", e)
 
-        Job._kill(job_id=self.id, session=session)
+        job = session.scalar(select(Job).where(Job.id == self.id).limit(1))
+        job.end_date = timezone.utcnow()
+        session.merge(job)
+        session.commit()
         raise AirflowException("Job shut down externally.")
 
     def on_kill(self):
@@ -201,7 +201,7 @@ class Job(Base, LoggingMixin):
             try:
                 span.set_attribute("heartbeat", str(self.latest_heartbeat))
                 # This will cause it to load from the db
-                self._merge_from(Job._fetch_from_db(self, session))
+                session.merge(self)
                 previous_heartbeat = self.latest_heartbeat
 
                 if self.state == JobState.RESTARTING:
@@ -217,17 +217,19 @@ class Job(Base, LoggingMixin):
                     if span.is_recording():
                         span.add_event(name="sleep", attributes={"sleep_for": sleep_for})
                     sleep(sleep_for)
-
-                job = Job._update_heartbeat(job=self, session=session)
-                self._merge_from(job)
-                time_since_last_heartbeat = (timezone.utcnow() - previous_heartbeat).total_seconds()
-                health_check_threshold_value = health_check_threshold(self.job_type, self.heartrate)
-                if time_since_last_heartbeat > health_check_threshold_value:
-                    self.log.info("Heartbeat recovered after %.2f seconds", time_since_last_heartbeat)
-                # At this point, the DB has updated.
-                previous_heartbeat = self.latest_heartbeat
-
-                heartbeat_callback(session)
+                # Update last heartbeat time
+                with create_session() as session:
+                    # Make the session aware of this object
+                    session.merge(self)
+                    self.latest_heartbeat = timezone.utcnow()
+                    session.commit()
+                    time_since_last_heartbeat = (timezone.utcnow() - previous_heartbeat).total_seconds()
+                    health_check_threshold_value = health_check_threshold(self.job_type, self.heartrate)
+                    if time_since_last_heartbeat > health_check_threshold_value:
+                        self.log.info("Heartbeat recovered after %.2f seconds", time_since_last_heartbeat)
+                    # At this point, the DB has updated.
+                    previous_heartbeat = self.latest_heartbeat
+                    heartbeat_callback(session)
                 self.log.debug("[heartbeat]")
                 self.heartbeat_failed = False
             except OperationalError:
@@ -260,36 +262,23 @@ class Job(Base, LoggingMixin):
         Stats.incr(self.__class__.__name__.lower() + "_start", 1, 1)
         self.state = JobState.RUNNING
         self.start_date = timezone.utcnow()
-        self._merge_from(Job._add_to_db(job=self, session=session))
+        session.add(self)
+        session.commit()
         make_transient(self)
 
     @provide_session
     def complete_execution(self, session: Session = NEW_SESSION):
         get_listener_manager().hook.before_stopping(component=self)
         self.end_date = timezone.utcnow()
-        Job._update_in_db(job=self, session=session)
+        session.merge(self)
+        session.commit()
         Stats.incr(self.__class__.__name__.lower() + "_end", 1, 1)
 
     @provide_session
-    def most_recent_job(self, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
+    def most_recent_job(self, session: Session = NEW_SESSION) -> Job | None:
         """Return the most recent job of this type, if any, based on last heartbeat received."""
         return most_recent_job(self.job_type, session=session)
 
-    def _merge_from(self, job: Job | JobPydantic | None):
-        if job is None:
-            self.log.error("Job is empty: %s", self.id)
-            return
-        self.id = job.id
-        self.dag_id = job.dag_id
-        self.state = job.state
-        self.job_type = job.job_type
-        self.start_date = job.start_date
-        self.end_date = job.end_date
-        self.latest_heartbeat = job.latest_heartbeat
-        self.executor_class = job.executor_class
-        self.hostname = job.hostname
-        self.unixname = job.unixname
-
     @staticmethod
     def _heartrate(job_type: str) -> float:
         if job_type == "TriggererJob":
@@ -312,74 +301,9 @@ class Job(Base, LoggingMixin):
             and (timezone.utcnow() - latest_heartbeat).total_seconds() < health_check_threshold_value
         )
 
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    def _kill(job_id: str, session: Session = NEW_SESSION) -> Job | JobPydantic:
-        job = session.scalar(select(Job).where(Job.id == job_id).limit(1))
-        job.end_date = timezone.utcnow()
-        session.merge(job)
-        session.commit()
-        return job
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    @retry_db_transaction
-    def _fetch_from_db(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
-        if isinstance(job, Job):
-            # not Internal API
-            session.merge(job)
-            return job
-        # Internal API,
-        return session.scalar(select(Job).where(Job.id == job.id).limit(1))
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    def _add_to_db(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic:
-        if isinstance(job, JobPydantic):
-            orm_job = Job()
-            orm_job._merge_from(job)
-        else:
-            orm_job = job
-        session.add(orm_job)
-        session.commit()
-        return orm_job
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    def _update_in_db(job: Job | JobPydantic, session: Session = NEW_SESSION):
-        if isinstance(job, Job):
-            # not Internal API
-            session.merge(job)
-            session.commit()
-        # Internal API.
-        orm_job: Job | None = session.scalar(select(Job).where(Job.id == job.id).limit(1))
-        if orm_job is None:
-            return
-        orm_job._merge_from(job)
-        session.merge(orm_job)
-        session.commit()
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    @retry_db_transaction
-    def _update_heartbeat(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic:
-        orm_job: Job | None = session.scalar(select(Job).where(Job.id == job.id).limit(1))
-        if orm_job is None:
-            return job
-        orm_job.latest_heartbeat = timezone.utcnow()
-        session.merge(orm_job)
-        session.commit()
-        return orm_job
-
 
-@internal_api_call
 @provide_session
-def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
+def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | None:
     """
     Return the most recent job of this type, if any, based on last heartbeat received.
 
@@ -434,7 +358,7 @@ def execute_job(job: Job, execute_callable: Callable[[], int | None]) -> int | N
     which happens in the "complete_execution" step (which again can be executed locally in case of
     database operations or over the Internal API call.
 
-    :param job: Job to execute - it can be either DB job or it's Pydantic serialized version. It does
+    :param job: Job to execute - DB job. It does
         not really matter, because except of running the heartbeat and state setting, the runner
         should not modify the job state.
 
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py b/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
index b3ceaa68700..aa5b30f5ab7 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
@@ -119,7 +119,6 @@ def _initialize_method_map() -> dict[str, Callable]:
         expand_alias_to_assets,
         FileTaskHandler._render_filename_db_access,
         Job._add_to_db,
-        Job._fetch_from_db,
         Job._kill,
         Job._update_heartbeat,
         Job._update_in_db,
diff --git a/tests/jobs/test_base_job.py b/tests/jobs/test_base_job.py
index 4d5f3787ba6..63f5ef1910a 100644
--- a/tests/jobs/test_base_job.py
+++ b/tests/jobs/test_base_job.py
@@ -228,10 +228,13 @@ class TestJob:
         job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10)
         assert job.is_alive() is False, "Completed jobs even with recent heartbeat should not be alive"
 
-    def test_heartbeat_failed(self, caplog):
+    @patch("airflow.jobs.job.create_session")
+    def test_heartbeat_failed(self, mock_create_session, caplog):
         when = timezone.utcnow() - datetime.timedelta(seconds=60)
-        mock_session = Mock(name="MockSession")
-        mock_session.commit.side_effect = OperationalError("Force fail", {}, None)
+        with create_session() as session:
+            mock_session = Mock(spec_set=session, name="MockSession")
+            mock_create_session.return_value.__enter__.return_value = mock_session
+            mock_session.commit.side_effect = OperationalError("Force fail", {}, None)
         job = Job(heartrate=10, state=State.RUNNING)
         job.latest_heartbeat = when
         with caplog.at_level(logging.ERROR):
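
The net effect of the job.py changes above is that every former @internal_api_call
round-trip (Job._kill, Job._fetch_from_db, Job._add_to_db, Job._update_in_db,
Job._update_heartbeat) collapses into a plain local transaction. Below is a minimal
sketch of the resulting pattern, assuming an initialized Airflow metadata database;
the standalone script is illustrative and not part of the commit:

    # Sketch of the direct-session pattern this commit adopts in
    # Job.prepare_for_execution() and Job.heartbeat(): open a local
    # session, write, commit. No RPC indirection remains.
    from airflow.jobs.job import Job
    from airflow.utils import timezone
    from airflow.utils.session import create_session

    job = Job(heartrate=5)  # illustrative values, mirroring Job(heartrate=10, ...) in the test diff

    with create_session() as session:
        session.add(job)  # persist the job row, as prepare_for_execution() now does
        session.commit()

    with create_session() as session:
        merged = session.merge(job)  # attach a persistent copy to this session
        merged.latest_heartbeat = timezone.utcnow()
        session.commit()  # one local commit replaces the _update_heartbeat round-trip

Note that session.merge() returns the persistent copy; mutating that returned
instance, rather than the detached original, is what guarantees the commit picks
up the new heartbeat.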
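On the test side, the updated test_heartbeat_failed shows the seam for injecting
failures once create_session is used directly: patch the name where it is looked up
(airflow.jobs.job.create_session) and wire the mock's __enter__ to hand back a
session whose commit() raises. A condensed sketch of that mocking pattern, assuming
airflow is importable; heartbeat_once below is a hypothetical stand-in for the code
under test:

    # Sketch: forcing an OperationalError out of a "with create_session()"
    # block by patching the context manager where the caller looks it up.
    from unittest.mock import Mock, patch

    from sqlalchemy.exc import OperationalError


    def heartbeat_once():
        # Hypothetical stand-in for the session block inside Job.heartbeat().
        import airflow.jobs.job as job_module

        with job_module.create_session() as session:
            session.commit()


    with patch("airflow.jobs.job.create_session") as mock_create_session:
        mock_session = Mock(name="MockSession")
        # "with create_session() as session:" now yields mock_session.
        mock_create_session.return_value.__enter__.return_value = mock_session
        mock_session.commit.side_effect = OperationalError("Force fail", {}, None)
        try:
            heartbeat_once()
        except OperationalError:
            print("commit failed, as forced by the mock")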