ashb commented on a change in pull request #17581:
URL: https://github.com/apache/airflow/pull/17581#discussion_r687800528



##########
File path: tests/jobs/test_local_task_job.py
##########
@@ -273,42 +274,41 @@ def test_heartbeat_failed_fast(self):
                 delta = (time2 - time1).total_seconds()
                 assert abs(delta - job.heartrate) < 0.5
 
-    def test_mark_success_no_kill(self):
+    @patch.object(StandardTaskRunner, 'return_code')
+    def test_mark_success_no_kill(self, mock_return_code, caplog, dag_maker):
         """
         Test that ensures that mark_success in the UI doesn't cause
         the task to fail, and that the task exits
         """
-        dag = self.dagbag.dags.get('test_mark_success')
-        task = dag.get_task('task1')
-
         session = settings.Session()
 
-        dag.clear()
-        dag.create_dagrun(
-            run_id="test",
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
-        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        def task_function(ti):
+            assert ti.state == State.RUNNING
+            ti.state = State.SUCCESS
+            session.merge(ti)
+            session.commit()

Review comment:
       We probably need a comment here to say what we are "simulating".
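   
       Something along these lines, perhaps (the comment wording is just a sketch; the code itself is unchanged from the diff):
   
    ```python
    def task_function(ti):
        # Simulate a user clicking "Mark Success" in the UI while the task
        # is still running: flip the TI state to SUCCESS directly in the DB.
        assert ti.state == State.RUNNING
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()
    ```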

##########
File path: tests/jobs/test_local_task_job.py
##########
@@ -273,42 +274,41 @@ def test_heartbeat_failed_fast(self):
                 delta = (time2 - time1).total_seconds()
                 assert abs(delta - job.heartrate) < 0.5
 
-    def test_mark_success_no_kill(self):
+    @patch.object(StandardTaskRunner, 'return_code')
+    def test_mark_success_no_kill(self, mock_return_code, caplog, dag_maker):
         """
         Test that ensures that mark_success in the UI doesn't cause
         the task to fail, and that the task exits
         """
-        dag = self.dagbag.dags.get('test_mark_success')
-        task = dag.get_task('task1')
-
         session = settings.Session()
 
-        dag.clear()
-        dag.create_dagrun(
-            run_id="test",
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
-        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        def task_function(ti):
+            assert ti.state == State.RUNNING
+            ti.state = State.SUCCESS
+            session.merge(ti)
+            session.commit()
+
+        with dag_maker('test_mark_success'):
+        task1 = PythonOperator(task_id="task1", python_callable=task_function)
+        dag_maker.create_dagrun()
+
+        ti = TaskInstance(task=task1, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
-        settings.engine.dispose()
-        process = multiprocessing.Process(target=job1.run)
-        process.start()
-        for _ in range(0, 50):
-            if ti.state == State.RUNNING:
-                break
-            time.sleep(0.1)
-            ti.refresh_from_db()
-        assert State.RUNNING == ti.state
-        ti.state = State.SUCCESS
-        session.merge(ti)
-        session.commit()
-        process.join(timeout=10)
+
+        # The return code when we mark success in the UI is None
+        def dummy_return_code(*args, **kwargs):
+            return None
+
+        mock_return_code.side_effect = dummy_return_code

Review comment:
       I think you can do this
   
   ```suggestion
           mock_return_code.return_value = None
   ```
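   
       (`return_value` is the idiomatic way to stub a constant result; `side_effect` is only needed when the return value has to vary per call or raise.)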

##########
File path: tests/jobs/test_local_task_job.py
##########
@@ -587,23 +584,24 @@ def task_function(ti):
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.task_runner.start()
         settings.engine.dispose()
-        process = multiprocessing.Process(target=job1.run)
-        process.start()
-        time.sleep(0.3)
-        process.join(timeout=10)
+        caplog.set_level(logging.INFO)
+        job1.run()  # os.kill will make this run very short

Review comment:
       We should probably still put a timeout here, just to be defensive.
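   
       One option (an untested sketch; it swaps the old multiprocessing pattern for a thread so we can still join with a timeout):
   
    ```python
    import threading
   
    runner = threading.Thread(target=job1.run)
    runner.start()
    runner.join(timeout=10)  # defensive: don't hang the suite if os.kill never fires
    assert not runner.is_alive()
    ```
   
       Alternatively, if pytest-timeout is available in this suite, a `@pytest.mark.timeout(10)` marker on the test keeps it declarative.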



