This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 55e419e95ab Remove AIP-44 from Job (#44493)
55e419e95ab is described below

commit 55e419e95ab027d161cef95571300af9b2c81a0d
Author: Jarek Potiuk <ja...@potiuk.com>
AuthorDate: Sat Nov 30 03:19:32 2024 +0100

    Remove AIP-44 from Job (#44493)
    
    Part of #44436
---
 airflow/jobs/job.py                                | 128 +++++----------------
 .../providers/edge/worker_api/routes/rpc_api.py    |   1 -
 tests/jobs/test_base_job.py                        |   9 +-
 3 files changed, 32 insertions(+), 106 deletions(-)
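
For readers skimming the diff below: before this commit, Job persisted itself
through @internal_api_call static methods so the same call could run either
against the local database or over the AIP-44 internal RPC; after it, callers
own the session and write directly. A condensed before/after sketch, assembled
from the hunks below rather than quoted verbatim:

    # Before (AIP-44): the DB write hides behind an RPC-capable static method.
    @staticmethod
    @internal_api_call
    @provide_session
    def _kill(job_id: str, session: Session = NEW_SESSION) -> Job:
        job = session.scalar(select(Job).where(Job.id == job_id).limit(1))
        job.end_date = timezone.utcnow()
        session.merge(job)
        session.commit()
        return job

    # After: the caller performs the same write inline on its own session.
    job = session.scalar(select(Job).where(Job.id == self.id).limit(1))
    job.end_date = timezone.utcnow()
    session.merge(job)
    session.commit()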

diff --git a/airflow/jobs/job.py b/airflow/jobs/job.py
index 6e802372d83..75a075efdfc 100644
--- a/airflow/jobs/job.py
+++ b/airflow/jobs/job.py
@@ -26,13 +26,11 @@ from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import backref, foreign, relationship
 from sqlalchemy.orm.session import make_transient
 
-from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.base import ID_LEN, Base
-from airflow.serialization.pydantic.job import JobPydantic
 from airflow.stats import Stats
 from airflow.traces.tracer import Trace, add_span
 from airflow.utils import timezone
@@ -40,8 +38,7 @@ from airflow.utils.helpers import convert_camel_to_snake
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
 from airflow.utils.platform import getuser
-from airflow.utils.retries import retry_db_transaction
-from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
 from airflow.utils.state import JobState
 
@@ -168,7 +165,10 @@ class Job(Base, LoggingMixin):
         except Exception as e:
             self.log.error("on_kill() method failed: %s", e)
 
-        Job._kill(job_id=self.id, session=session)
+        job = session.scalar(select(Job).where(Job.id == self.id).limit(1))
+        job.end_date = timezone.utcnow()
+        session.merge(job)
+        session.commit()
         raise AirflowException("Job shut down externally.")
 
     def on_kill(self):
@@ -201,7 +201,7 @@ class Job(Base, LoggingMixin):
             try:
                 span.set_attribute("heartbeat", str(self.latest_heartbeat))
                 # This will cause it to load from the db
-                self._merge_from(Job._fetch_from_db(self, session))
+                session.merge(self)
                 previous_heartbeat = self.latest_heartbeat
 
                 if self.state == JobState.RESTARTING:
@@ -217,17 +217,19 @@ class Job(Base, LoggingMixin):
                 if span.is_recording():
                     span.add_event(name="sleep", attributes={"sleep_for": sleep_for})
                 sleep(sleep_for)
-
-                job = Job._update_heartbeat(job=self, session=session)
-                self._merge_from(job)
-                time_since_last_heartbeat = (timezone.utcnow() - previous_heartbeat).total_seconds()
-                health_check_threshold_value = health_check_threshold(self.job_type, self.heartrate)
-                if time_since_last_heartbeat > health_check_threshold_value:
-                    self.log.info("Heartbeat recovered after %.2f seconds", time_since_last_heartbeat)
-                # At this point, the DB has updated.
-                previous_heartbeat = self.latest_heartbeat
-
-                heartbeat_callback(session)
+                # Update last heartbeat time
+                with create_session() as session:
+                    # Make the session aware of this object
+                    session.merge(self)
+                    self.latest_heartbeat = timezone.utcnow()
+                    session.commit()
+                    time_since_last_heartbeat = (timezone.utcnow() - previous_heartbeat).total_seconds()
+                    health_check_threshold_value = health_check_threshold(self.job_type, self.heartrate)
+                    if time_since_last_heartbeat > health_check_threshold_value:
+                        self.log.info("Heartbeat recovered after %.2f seconds", time_since_last_heartbeat)
+                    # At this point, the DB has updated.
+                    previous_heartbeat = self.latest_heartbeat
+                    heartbeat_callback(session)
                 self.log.debug("[heartbeat]")
                 self.heartbeat_failed = False
             except OperationalError:
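
A condensed sketch of the new heartbeat write path in the hunk above, assuming
only the names visible there (create_session, latest_heartbeat,
heartbeat_callback):

    # Each heartbeat opens a short-lived session, attaches this Job to it,
    # stamps the heartbeat time, and commits before invoking the callback.
    with create_session() as session:
        session.merge(self)
        self.latest_heartbeat = timezone.utcnow()
        session.commit()
        heartbeat_callback(session)

Since the commit now happens inside heartbeat() itself, an OperationalError
raised by the database is caught by the handler shown above, which is exactly
what the updated test at the end of this mail simulates.
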
@@ -260,36 +262,23 @@ class Job(Base, LoggingMixin):
         Stats.incr(self.__class__.__name__.lower() + "_start", 1, 1)
         self.state = JobState.RUNNING
         self.start_date = timezone.utcnow()
-        self._merge_from(Job._add_to_db(job=self, session=session))
+        session.add(self)
+        session.commit()
         make_transient(self)
 
     @provide_session
     def complete_execution(self, session: Session = NEW_SESSION):
         get_listener_manager().hook.before_stopping(component=self)
         self.end_date = timezone.utcnow()
-        Job._update_in_db(job=self, session=session)
+        session.merge(self)
+        session.commit()
         Stats.incr(self.__class__.__name__.lower() + "_end", 1, 1)
 
     @provide_session
-    def most_recent_job(self, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
+    def most_recent_job(self, session: Session = NEW_SESSION) -> Job | None:
         """Return the most recent job of this type, if any, based on last 
heartbeat received."""
         return most_recent_job(self.job_type, session=session)
 
-    def _merge_from(self, job: Job | JobPydantic | None):
-        if job is None:
-            self.log.error("Job is empty: %s", self.id)
-            return
-        self.id = job.id
-        self.dag_id = job.dag_id
-        self.state = job.state
-        self.job_type = job.job_type
-        self.start_date = job.start_date
-        self.end_date = job.end_date
-        self.latest_heartbeat = job.latest_heartbeat
-        self.executor_class = job.executor_class
-        self.hostname = job.hostname
-        self.unixname = job.unixname
-
     @staticmethod
     def _heartrate(job_type: str) -> float:
         if job_type == "TriggererJob":
@@ -312,74 +301,9 @@ class Job(Base, LoggingMixin):
             and (timezone.utcnow() - latest_heartbeat).total_seconds() < health_check_threshold_value
         )
 
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    def _kill(job_id: str, session: Session = NEW_SESSION) -> Job | JobPydantic:
-        job = session.scalar(select(Job).where(Job.id == job_id).limit(1))
-        job.end_date = timezone.utcnow()
-        session.merge(job)
-        session.commit()
-        return job
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    @retry_db_transaction
-    def _fetch_from_db(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
-        if isinstance(job, Job):
-            # not Internal API
-            session.merge(job)
-            return job
-        # Internal API,
-        return session.scalar(select(Job).where(Job.id == job.id).limit(1))
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    def _add_to_db(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic:
-        if isinstance(job, JobPydantic):
-            orm_job = Job()
-            orm_job._merge_from(job)
-        else:
-            orm_job = job
-        session.add(orm_job)
-        session.commit()
-        return orm_job
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    def _update_in_db(job: Job | JobPydantic, session: Session = NEW_SESSION):
-        if isinstance(job, Job):
-            # not Internal API
-            session.merge(job)
-            session.commit()
-        # Internal API.
-        orm_job: Job | None = session.scalar(select(Job).where(Job.id == job.id).limit(1))
-        if orm_job is None:
-            return
-        orm_job._merge_from(job)
-        session.merge(orm_job)
-        session.commit()
-
-    @staticmethod
-    @internal_api_call
-    @provide_session
-    @retry_db_transaction
-    def _update_heartbeat(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic:
-        orm_job: Job | None = session.scalar(select(Job).where(Job.id == job.id).limit(1))
-        if orm_job is None:
-            return job
-        orm_job.latest_heartbeat = timezone.utcnow()
-        session.merge(orm_job)
-        session.commit()
-        return orm_job
-
 
-@internal_api_call
 @provide_session
-def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
+def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | None:
     """
    Return the most recent job of this type, if any, based on last heartbeat received.
 
@@ -434,7 +358,7 @@ def execute_job(job: Job, execute_callable: Callable[[], int | None]) -> int | N
     which happens in the "complete_execution" step (which again can be executed locally in case of
     database operations or over the Internal API call.
 
-    :param job: Job to execute - it can be either DB job or it's Pydantic serialized version. It does
+    :param job: Job to execute - DB job. It does
      not really matter, because except of running the heartbeat and state setting,
       the runner should not modify the job state.
 
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py b/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
index b3ceaa68700..aa5b30f5ab7 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
@@ -119,7 +119,6 @@ def _initialize_method_map() -> dict[str, Callable]:
         expand_alias_to_assets,
         FileTaskHandler._render_filename_db_access,
         Job._add_to_db,
-        Job._fetch_from_db,
         Job._kill,
         Job._update_heartbeat,
         Job._update_in_db,
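
For context: _initialize_method_map builds the registry of callables that the
edge worker's RPC endpoint is allowed to dispatch to, so RPC targets deleted
from Job also get dropped from it. A schematic of the pattern (the key format
here is an assumption, not necessarily the provider's exact code):

    from typing import Callable

    def _initialize_method_map() -> dict[str, Callable]:
        functions: list[Callable] = [
            Job._add_to_db,
            # Job._fetch_from_db is gone: the method was removed from job.py above.
            Job._kill,
        ]
        return {f"{f.__module__}.{f.__qualname__}": f for f in functions}
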
diff --git a/tests/jobs/test_base_job.py b/tests/jobs/test_base_job.py
index 4d5f3787ba6..63f5ef1910a 100644
--- a/tests/jobs/test_base_job.py
+++ b/tests/jobs/test_base_job.py
@@ -228,10 +228,13 @@ class TestJob:
         job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10)
         assert job.is_alive() is False, "Completed jobs even with recent heartbeat should not be alive"
 
-    def test_heartbeat_failed(self, caplog):
+    @patch("airflow.jobs.job.create_session")
+    def test_heartbeat_failed(self, mock_create_session, caplog):
         when = timezone.utcnow() - datetime.timedelta(seconds=60)
-        mock_session = Mock(name="MockSession")
-        mock_session.commit.side_effect = OperationalError("Force fail", {}, None)
+        with create_session() as session:
+            mock_session = Mock(spec_set=session, name="MockSession")
+            mock_create_session.return_value.__enter__.return_value = mock_session
+            mock_session.commit.side_effect = OperationalError("Force fail", {}, None)
         job = Job(heartrate=10, state=State.RUNNING)
         job.latest_heartbeat = when
         with caplog.at_level(logging.ERROR):
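
The rewritten test leans on how unittest.mock models context managers: patching
create_session replaces the function with a MagicMock, and whatever is assigned
to return_value.__enter__.return_value is what the "with" block yields. A
standalone illustration of that wiring (stdlib only; the error type is swapped
for brevity):

    from unittest.mock import MagicMock, patch

    with patch("airflow.jobs.job.create_session") as mock_create_session:
        mock_session = MagicMock(name="MockSession")
        # "with create_session() as session:" now yields mock_session ...
        mock_create_session.return_value.__enter__.return_value = mock_session
        # ... and its commit() raises, driving the heartbeat error branch.
        mock_session.commit.side_effect = RuntimeError("Force fail")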
