This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 68adc0e059 Refactor commands to unify daemon context handling (#34945)
68adc0e059 is described below

commit 68adc0e059ac65f20dfc7cf0038edb96b1244d32
Author: Daniel Dyląg <bi...@users.noreply.github.com>
AuthorDate: Tue Oct 24 10:16:47 2023 +0200

    Refactor commands to unify daemon context handling (#34945)
---
 .pre-commit-config.yaml                       |  1 +
 airflow/cli/commands/celery_command.py        | 89 +++++++++----------------
 airflow/cli/commands/daemon_utils.py          | 82 +++++++++++++++++++++++
 airflow/cli/commands/dag_processor_command.py | 32 ++-------
 airflow/cli/commands/internal_api_command.py  | 95 ++++++++++++---------------
 airflow/cli/commands/kerberos_command.py      | 29 ++------
 airflow/cli/commands/scheduler_command.py     | 50 +++++---------
 airflow/cli/commands/triggerer_command.py     | 48 ++++----------
 airflow/cli/commands/webserver_command.py     | 93 +++++++++++---------------
 tests/cli/commands/test_celery_command.py     | 18 ++---
 tests/cli/commands/test_kerberos_command.py   | 29 ++++----
 11 files changed, 256 insertions(+), 310 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 85cfc050c1..d7c10d53ff 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -509,6 +509,7 @@ repos:
           ^airflow/api_connexion/openapi/v1.yaml$|
           ^airflow/auth/managers/fab/security_manager/|
           ^airflow/cli/commands/webserver_command.py$|
+          ^airflow/cli/commands/internal_api_command.py$|
           ^airflow/config_templates/|
           ^airflow/models/baseoperator.py$|
           ^airflow/operators/__init__.py$|
diff --git a/airflow/cli/commands/celery_command.py b/airflow/cli/commands/celery_command.py
index eb53d6f60d..5e3e01042a 100644
--- a/airflow/cli/commands/celery_command.py
+++ b/airflow/cli/commands/celery_command.py
@@ -23,19 +23,18 @@ import sys
 from contextlib import contextmanager
 from multiprocessing import Process
 
-import daemon
 import psutil
 import sqlalchemy.exc
 from celery import maybe_patch_concurrency  # type: ignore[attr-defined]
 from celery.app.defaults import DEFAULT_TASK_LOG_FMT
 from celery.signals import after_setup_logger
-from daemon.pidfile import TimeoutPIDLockFile
 from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile
 
 from airflow import settings
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
 from airflow.configuration import conf
 from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging
+from airflow.utils.cli import setup_locations
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.utils.serve_logs import serve_logs
 
@@ -68,28 +67,9 @@ def flower(args):
     if args.flower_conf:
         options.append(f"--conf={args.flower_conf}")
 
-    if args.daemon:
-        pidfile, stdout, stderr, _ = setup_locations(
-            process="flower",
-            pid=args.pid,
-            stdout=args.stdout,
-            stderr=args.stderr,
-            log=args.log_file,
-        )
-        with open(stdout, "a") as stdout, open(stderr, "a") as stderr:
-            stdout.truncate(0)
-            stderr.truncate(0)
-
-            ctx = daemon.DaemonContext(
-                pidfile=TimeoutPIDLockFile(pidfile, -1),
-                stdout=stdout,
-                stderr=stderr,
-                umask=int(settings.DAEMON_UMASK, 8),
-            )
-            with ctx:
-                celery_app.start(options)
-    else:
-        celery_app.start(options)
+    run_command_with_daemon_option(
+        args=args, process_name="flower", callback=lambda: celery_app.start(options)
+    )
 
 
 @contextmanager
@@ -152,15 +132,6 @@ def worker(args):
     if autoscale is None and conf.has_option("celery", "worker_autoscale"):
         autoscale = conf.get("celery", "worker_autoscale")
 
-    # Setup locations
-    pid_file_path, stdout, stderr, log_file = setup_locations(
-        process=WORKER_PROCESS_NAME,
-        pid=args.pid,
-        stdout=args.stdout,
-        stderr=args.stderr,
-        log=args.log_file,
-    )
-
     if hasattr(celery_app.backend, "ResultSession"):
         # Pre-create the database tables now, otherwise SQLA via Celery has a
         # race condition where one of the subprocesses can die with "Table
@@ -181,6 +152,10 @@ def worker(args):
     celery_log_level = conf.get("logging", "CELERY_LOGGING_LEVEL")
     if not celery_log_level:
         celery_log_level = conf.get("logging", "LOGGING_LEVEL")
+
+    # Setup pid file location
+    worker_pid_file_path, _, _, _ = setup_locations(process=WORKER_PROCESS_NAME, pid=args.pid)
+
     # Setup Celery worker
     options = [
         "worker",
@@ -195,7 +170,7 @@ def worker(args):
         "--loglevel",
         celery_log_level,
         "--pidfile",
-        pid_file_path,
+        worker_pid_file_path,
     ]
     if autoscale:
         options.extend(["--autoscale", autoscale])
@@ -214,33 +189,31 @@ def worker(args):
         # executed.
         maybe_patch_concurrency(["-P", pool])
 
-    if args.daemon:
-        # Run Celery worker as daemon
-        handle = setup_logging(log_file)
-
-        with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
-            if args.umask:
-                umask = args.umask
-            else:
-                umask = conf.get("celery", "worker_umask", fallback=settings.DAEMON_UMASK)
-
-            stdout_handle.truncate(0)
-            stderr_handle.truncate(0)
-
-            daemon_context = daemon.DaemonContext(
-                files_preserve=[handle],
-                umask=int(umask, 8),
-                stdout=stdout_handle,
-                stderr=stderr_handle,
-            )
-            with daemon_context, _serve_logs(skip_serve_logs):
-                celery_app.worker_main(options)
+    _, stdout, stderr, log_file = setup_locations(
+        process=WORKER_PROCESS_NAME,
+        stdout=args.stdout,
+        stderr=args.stderr,
+        log=args.log_file,
+    )
 
-    else:
-        # Run Celery worker in the same process
+    def run_celery_worker():
         with _serve_logs(skip_serve_logs):
             celery_app.worker_main(options)
 
+    if args.umask:
+        umask = args.umask
+    else:
+        umask = conf.get("celery", "worker_umask", fallback=settings.DAEMON_UMASK)
+
+    run_command_with_daemon_option(
+        args=args,
+        process_name=WORKER_PROCESS_NAME,
+        callback=run_celery_worker,
+        should_setup_logging=True,
+        umask=umask,
+        pid_file=worker_pid_file_path,
+    )
+
 
 @cli_utils.action_cli
 @providers_configuration_loaded
diff --git a/airflow/cli/commands/daemon_utils.py b/airflow/cli/commands/daemon_utils.py
new file mode 100644
index 0000000000..9184b1f7db
--- /dev/null
+++ b/airflow/cli/commands/daemon_utils.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import signal
+from argparse import Namespace
+from typing import Callable
+
+from daemon import daemon
+from daemon.pidfile import TimeoutPIDLockFile
+
+from airflow import settings
+from airflow.utils.cli import setup_locations, setup_logging, sigint_handler, sigquit_handler
+from airflow.utils.process_utils import check_if_pidfile_process_is_running
+
+
+def run_command_with_daemon_option(
+    *,
+    args: Namespace,
+    process_name: str,
+    callback: Callable,
+    should_setup_logging: bool = False,
+    umask: str = settings.DAEMON_UMASK,
+    pid_file: str | None = None,
+):
+    """Run the command in a daemon process if daemon mode is enabled or within this process if not.
+
+    :param args: the set of arguments passed to the original CLI command
+    :param process_name: process name used in naming log and PID files for the daemon
+    :param callback: the actual command to run with or without daemon context
+    :param should_setup_logging: if true, then a log file handler for the daemon process will be created
+    :param umask: file access creation mask ("umask") to set for the process on daemon start
+    :param pid_file: if specified, this file path is used to store daemon process PID.
+        If not specified, a file path is generated with the default pattern.
+    """
+    if args.daemon:
+        pid, stdout, stderr, log_file = setup_locations(
+            process=process_name, stdout=args.stdout, stderr=args.stderr, log=args.log_file
+        )
+        if pid_file:
+            pid = pid_file
+
+        # Check if the process is already running; if not but a pidfile exists, clean it up
+        check_if_pidfile_process_is_running(pid_file=pid, process_name=process_name)
+
+        if should_setup_logging:
+            files_preserve = [setup_logging(log_file)]
+        else:
+            files_preserve = None
+        with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
+            stdout_handle.truncate(0)
+            stderr_handle.truncate(0)
+
+            ctx = daemon.DaemonContext(
+                pidfile=TimeoutPIDLockFile(pid, -1),
+                files_preserve=files_preserve,
+                stdout=stdout_handle,
+                stderr=stderr_handle,
+                umask=int(umask, 8),
+            )
+
+            with ctx:
+                callback()
+    else:
+        signal.signal(signal.SIGINT, sigint_handler)
+        signal.signal(signal.SIGTERM, sigint_handler)
+        signal.signal(signal.SIGQUIT, sigquit_handler)
+        callback()
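
For orientation, every refactored command in this commit delegates to the helper above in the same way. A minimal sketch of the calling pattern, assuming an argparse Namespace that carries the usual daemon/stdout/stderr/log_file CLI attributes (the "example" process name and the callback body are purely illustrative, not part of the commit):

    from argparse import Namespace

    from airflow.cli.commands.daemon_utils import run_command_with_daemon_option


    def example_command(args: Namespace) -> None:
        # With --daemon the callback runs inside a python-daemon DaemonContext
        # (pid file, redirected stdout/stderr, umask); without it, the callback
        # runs in-process with SIGINT/SIGTERM/SIGQUIT handlers installed.
        run_command_with_daemon_option(
            args=args,
            process_name="example",
            callback=lambda: print("actual work goes here"),
            should_setup_logging=True,
        )
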
diff --git a/airflow/cli/commands/dag_processor_command.py b/airflow/cli/commands/dag_processor_command.py
index cf880f6622..85bef2727d 100644
--- a/airflow/cli/commands/dag_processor_command.py
+++ b/airflow/cli/commands/dag_processor_command.py
@@ -21,16 +21,12 @@ import logging
 from datetime import timedelta
 from typing import Any
 
-import daemon
-from daemon.pidfile import TimeoutPIDLockFile
-
-from airflow import settings
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
 from airflow.configuration import conf
 from airflow.dag_processing.manager import DagFileProcessorManager
 from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner
 from airflow.jobs.job import Job, run_job
 from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 
 log = logging.getLogger(__name__)
@@ -66,23 +62,9 @@ def dag_processor(args):
 
     job_runner = _create_dag_processor_job_runner(args)
 
-    if args.daemon:
-        pid, stdout, stderr, log_file = setup_locations(
-            "dag-processor", args.pid, args.stdout, args.stderr, args.log_file
-        )
-        handle = setup_logging(log_file)
-        with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
-            stdout_handle.truncate(0)
-            stderr_handle.truncate(0)
-
-            ctx = daemon.DaemonContext(
-                pidfile=TimeoutPIDLockFile(pid, -1),
-                files_preserve=[handle],
-                stdout=stdout_handle,
-                stderr=stderr_handle,
-                umask=int(settings.DAEMON_UMASK, 8),
-            )
-            with ctx:
-                run_job(job=job_runner.job, execute_callable=job_runner._execute)
-    else:
-        run_job(job=job_runner.job, execute_callable=job_runner._execute)
+    run_command_with_daemon_option(
+        args=args,
+        process_name="dag-processor",
+        callback=lambda: run_job(job=job_runner.job, execute_callable=job_runner._execute),
+        should_setup_logging=True,
+    )
diff --git a/airflow/cli/commands/internal_api_command.py b/airflow/cli/commands/internal_api_command.py
index 73ed2e2501..f558c89cab 100644
--- a/airflow/cli/commands/internal_api_command.py
+++ b/airflow/cli/commands/internal_api_command.py
@@ -28,9 +28,7 @@ from pathlib import Path
 from tempfile import gettempdir
 from time import sleep
 
-import daemon
 import psutil
-from daemon.pidfile import TimeoutPIDLockFile
 from flask import Flask
 from flask_appbuilder import SQLA
 from flask_caching import Cache
@@ -40,14 +38,14 @@ from sqlalchemy.engine.url import make_url
 
 from airflow import settings
 from airflow.api_internal.internal_api_call import InternalApiConfig
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
 from airflow.cli.commands.webserver_command import GunicornMonitor
 from airflow.configuration import conf
 from airflow.exceptions import AirflowConfigException
 from airflow.logging_config import configure_logging
 from airflow.models import import_all_models
 from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging
-from airflow.utils.process_utils import check_if_pidfile_process_is_running
+from airflow.utils.cli import setup_locations
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.www.extensions.init_dagbag import init_dagbag
 from airflow.www.extensions.init_jinja_globals import init_jinja_globals
@@ -81,13 +79,6 @@ def internal_api(args):
             host=args.hostname,
         )
     else:
-        pid_file, stdout, stderr, log_file = setup_locations(
-            "internal-api", args.pid, args.stdout, args.stderr, args.log_file
-        )
-
-        # Check if Internal APi is already running if not, remove old pidfile
-        check_if_pidfile_process_is_running(pid_file=pid_file, process_name="internal-api")
-
         log.info(
             textwrap.dedent(
                 f"""\
@@ -101,6 +92,8 @@ def internal_api(args):
             )
         )
 
+        pid_file, _, _, _ = setup_locations("internal-api", pid=args.pid)
+
         run_args = [
             sys.executable,
             "-m",
@@ -137,25 +130,27 @@ def internal_api(args):
         # then have a copy of the app
         run_args += ["--preload"]
 
-        gunicorn_master_proc: psutil.Process | None = None
-
-        def kill_proc(signum, _):
+        def kill_proc(signum: int, gunicorn_master_proc: psutil.Process | subprocess.Popen):
             log.info("Received signal: %s. Closing gunicorn.", signum)
             gunicorn_master_proc.terminate()
             with suppress(TimeoutError):
                 gunicorn_master_proc.wait(timeout=30)
-            if gunicorn_master_proc.is_running():
+            if isinstance(gunicorn_master_proc, subprocess.Popen):
+                still_running = gunicorn_master_proc.poll() is not None
+            else:
+                still_running = gunicorn_master_proc.is_running()
+            if still_running:
                 gunicorn_master_proc.kill()
             sys.exit(0)
 
-        def monitor_gunicorn(gunicorn_master_pid: int):
+        def monitor_gunicorn(gunicorn_master_proc: psutil.Process | subprocess.Popen):
             # Register signal handlers
-            signal.signal(signal.SIGINT, kill_proc)
-            signal.signal(signal.SIGTERM, kill_proc)
+            signal.signal(signal.SIGINT, lambda signum, _: kill_proc(signum, gunicorn_master_proc))
+            signal.signal(signal.SIGTERM, lambda signum, _: kill_proc(signum, gunicorn_master_proc))
 
             # These run forever until SIG{INT, TERM, KILL, ...} signal is sent
             GunicornMonitor(
-                gunicorn_master_pid=gunicorn_master_pid,
+                gunicorn_master_pid=gunicorn_master_proc.pid,
                 num_workers_expected=num_workers,
                 master_timeout=120,
                 worker_refresh_interval=30,
@@ -163,45 +158,39 @@ def internal_api(args):
                 reload_on_plugin_change=False,
             ).start()
 
+        def start_and_monitor_gunicorn(args):
+            if args.daemon:
+                subprocess.Popen(run_args, close_fds=True)
+
+                # Reading pid of gunicorn master as it will be different that
+                # the one of process spawned above.
+                gunicorn_master_proc_pid = None
+                while not gunicorn_master_proc_pid:
+                    sleep(0.1)
+                    gunicorn_master_proc_pid = read_pid_from_pidfile(pid_file)
+
+                # Run Gunicorn monitor
+                gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid)
+                monitor_gunicorn(gunicorn_master_proc)
+            else:
+                with subprocess.Popen(run_args, close_fds=True) as gunicorn_master_proc:
+                    monitor_gunicorn(gunicorn_master_proc)
+
         if args.daemon:
             # This makes possible errors get reported before daemonization
             os.environ["SKIP_DAGS_PARSING"] = "True"
-            app = create_app(None)
+            create_app(None)
             os.environ.pop("SKIP_DAGS_PARSING")
 
-            handle = setup_logging(log_file)
-
-            pid_path = Path(pid_file)
-            pidlock_path = pid_path.with_name(f"{pid_path.stem}-monitor{pid_path.suffix}")
-
-            with open(stdout, "a") as stdout, open(stderr, "a") as stderr:
-                stdout.truncate(0)
-                stderr.truncate(0)
-
-                ctx = daemon.DaemonContext(
-                    pidfile=TimeoutPIDLockFile(pidlock_path, -1),
-                    files_preserve=[handle],
-                    stdout=stdout,
-                    stderr=stderr,
-                    umask=int(settings.DAEMON_UMASK, 8),
-                )
-                with ctx:
-                    subprocess.Popen(run_args, close_fds=True)
-
-                    # Reading pid of gunicorn main process as it will be different that
-                    # the one of process spawned above.
-                    gunicorn_master_proc_pid = None
-                    while not gunicorn_master_proc_pid:
-                        sleep(0.1)
-                        gunicorn_master_proc_pid = read_pid_from_pidfile(pid_file)
-
-                    # Run Gunicorn monitor
-                    gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid)
-                    monitor_gunicorn(gunicorn_master_proc.pid)
-
-        else:
-            with subprocess.Popen(run_args, close_fds=True) as gunicorn_master_proc:
-                monitor_gunicorn(gunicorn_master_proc.pid)
+        pid_file_path = Path(pid_file)
+        monitor_pid_file = str(pid_file_path.with_name(f"{pid_file_path.stem}-monitor{pid_file_path.suffix}"))
+        run_command_with_daemon_option(
+            args=args,
+            process_name="internal-api",
+            callback=lambda: start_and_monitor_gunicorn(args),
+            should_setup_logging=True,
+            pid_file=monitor_pid_file,
+        )
 
 
 def create_app(config=None, testing=False):
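
The reworked kill_proc/monitor_gunicorn pair above passes the gunicorn master process into the signal handlers instead of relying on a module-level variable. A rough, self-contained sketch of that closure pattern, assuming a plain subprocess.Popen child (the names here are illustrative, not the module's API):

    import signal
    import subprocess


    def kill_proc(signum: int, proc: subprocess.Popen) -> None:
        # Ask the child to terminate; escalate to SIGKILL if it does not exit in time.
        proc.terminate()
        try:
            proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            proc.kill()


    def monitor(proc: subprocess.Popen) -> None:
        # The lambdas close over the concrete process object, so the same handler
        # works whether the process came from Popen or was re-read from a pid file.
        signal.signal(signal.SIGINT, lambda signum, _frame: kill_proc(signum, proc))
        signal.signal(signal.SIGTERM, lambda signum, _frame: kill_proc(signum, proc))
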
diff --git a/airflow/cli/commands/kerberos_command.py b/airflow/cli/commands/kerberos_command.py
index 4dd63d52eb..8d33e7f8ef 100644
--- a/airflow/cli/commands/kerberos_command.py
+++ b/airflow/cli/commands/kerberos_command.py
@@ -17,13 +17,10 @@
 """Kerberos command."""
 from __future__ import annotations
 
-import daemon
-from daemon.pidfile import TimeoutPIDLockFile
-
 from airflow import settings
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
 from airflow.security import kerberos as krb
 from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 
 
@@ -33,22 +30,8 @@ def kerberos(args):
     """Start a kerberos ticket renewer."""
     print(settings.HEADER)
 
-    if args.daemon:
-        pid, stdout, stderr, _ = setup_locations(
-            "kerberos", args.pid, args.stdout, args.stderr, args.log_file
-        )
-        with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
-            stdout_handle.truncate(0)
-            stderr_handle.truncate(0)
-
-            ctx = daemon.DaemonContext(
-                pidfile=TimeoutPIDLockFile(pid, -1),
-                stdout=stdout_handle,
-                stderr=stderr_handle,
-                umask=int(settings.DAEMON_UMASK, 8),
-            )
-
-            with ctx:
-                krb.run(principal=args.principal, keytab=args.keytab)
-    else:
-        krb.run(principal=args.principal, keytab=args.keytab)
+    run_command_with_daemon_option(
+        args=args,
+        process_name="kerberos",
+        callback=lambda: krb.run(principal=args.principal, keytab=args.keytab),
+    )
diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py
index fd25951ad3..fef0b97b2d 100644
--- a/airflow/cli/commands/scheduler_command.py
+++ b/airflow/cli/commands/scheduler_command.py
@@ -18,31 +18,33 @@
 from __future__ import annotations
 
 import logging
-import signal
+from argparse import Namespace
 from contextlib import contextmanager
 from multiprocessing import Process
 
-import daemon
-from daemon.pidfile import TimeoutPIDLockFile
-
 from airflow import settings
 from airflow.api_internal.internal_api_call import InternalApiConfig
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
 from airflow.configuration import conf
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.jobs.job import Job, run_job
 from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
 from airflow.utils import cli as cli_utils
-from airflow.utils.cli import process_subdir, setup_locations, setup_logging, sigint_handler, sigquit_handler
+from airflow.utils.cli import process_subdir
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.utils.scheduler_health import serve_health_check
 
 log = logging.getLogger(__name__)
 
 
-def _run_scheduler_job(job_runner: SchedulerJobRunner, *, skip_serve_logs: bool) -> None:
+def _run_scheduler_job(args) -> None:
+    job_runner = SchedulerJobRunner(
+        job=Job(), subdir=process_subdir(args.subdir), num_runs=args.num_runs, do_pickle=args.do_pickle
+    )
+    ExecutorLoader.validate_database_executor_compatibility(job_runner.job.executor)
     InternalApiConfig.force_database_direct_access()
     enable_health_check = conf.getboolean("scheduler", "ENABLE_HEALTH_CHECK")
-    with _serve_logs(skip_serve_logs), _serve_health_check(enable_health_check):
+    with _serve_logs(args.skip_serve_logs), _serve_health_check(enable_health_check):
         try:
             run_job(job=job_runner.job, execute_callable=job_runner._execute)
         except Exception:
@@ -51,38 +53,16 @@ def _run_scheduler_job(job_runner: SchedulerJobRunner, *, skip_serve_logs: bool)
 
 @cli_utils.action_cli
 @providers_configuration_loaded
-def scheduler(args):
+def scheduler(args: Namespace):
     """Start Airflow Scheduler."""
     print(settings.HEADER)
 
-    job_runner = SchedulerJobRunner(
-        job=Job(), subdir=process_subdir(args.subdir), num_runs=args.num_runs, do_pickle=args.do_pickle
+    run_command_with_daemon_option(
+        args=args,
+        process_name="scheduler",
+        callback=lambda: _run_scheduler_job(args),
+        should_setup_logging=True,
     )
-    ExecutorLoader.validate_database_executor_compatibility(job_runner.job.executor)
-
-    if args.daemon:
-        pid, stdout, stderr, log_file = setup_locations(
-            "scheduler", args.pid, args.stdout, args.stderr, args.log_file
-        )
-        handle = setup_logging(log_file)
-        with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
-            stdout_handle.truncate(0)
-            stderr_handle.truncate(0)
-
-            ctx = daemon.DaemonContext(
-                pidfile=TimeoutPIDLockFile(pid, -1),
-                files_preserve=[handle],
-                stdout=stdout_handle,
-                stderr=stderr_handle,
-                umask=int(settings.DAEMON_UMASK, 8),
-            )
-            with ctx:
-                _run_scheduler_job(job_runner, skip_serve_logs=args.skip_serve_logs)
-    else:
-        signal.signal(signal.SIGINT, sigint_handler)
-        signal.signal(signal.SIGTERM, sigint_handler)
-        signal.signal(signal.SIGQUIT, sigquit_handler)
-        _run_scheduler_job(job_runner, skip_serve_logs=args.skip_serve_logs)
 
 
 @contextmanager
diff --git a/airflow/cli/commands/triggerer_command.py b/airflow/cli/commands/triggerer_command.py
index 5ddb4e23b6..3479480dbf 100644
--- a/airflow/cli/commands/triggerer_command.py
+++ b/airflow/cli/commands/triggerer_command.py
@@ -17,21 +17,17 @@
 """Triggerer command."""
 from __future__ import annotations
 
-import signal
 from contextlib import contextmanager
 from functools import partial
 from multiprocessing import Process
 from typing import Generator
 
-import daemon
-from daemon.pidfile import TimeoutPIDLockFile
-
 from airflow import settings
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
 from airflow.configuration import conf
 from airflow.jobs.job import Job, run_job
 from airflow.jobs.triggerer_job_runner import TriggererJobRunner
 from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging, sigint_handler, sigquit_handler
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.utils.serve_logs import serve_logs
 
@@ -51,6 +47,12 @@ def _serve_logs(skip_serve_logs: bool = False) -> Generator[None, None, None]:
             sub_proc.terminate()
 
 
+def triggerer_run(skip_serve_logs: bool, capacity: int, triggerer_heartrate: float):
+    with _serve_logs(skip_serve_logs):
+        triggerer_job_runner = TriggererJobRunner(job=Job(heartrate=triggerer_heartrate), capacity=capacity)
+        run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute)
+
+
 @cli_utils.action_cli
 @providers_configuration_loaded
 def triggerer(args):
@@ -59,33 +61,9 @@ def triggerer(args):
     print(settings.HEADER)
     triggerer_heartrate = conf.getfloat("triggerer", "JOB_HEARTBEAT_SEC")
 
-    if args.daemon:
-        pid, stdout, stderr, log_file = setup_locations(
-            "triggerer", args.pid, args.stdout, args.stderr, args.log_file
-        )
-        handle = setup_logging(log_file)
-        with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
-            stdout_handle.truncate(0)
-            stderr_handle.truncate(0)
-
-            daemon_context = daemon.DaemonContext(
-                pidfile=TimeoutPIDLockFile(pid, -1),
-                files_preserve=[handle],
-                stdout=stdout_handle,
-                stderr=stderr_handle,
-                umask=int(settings.DAEMON_UMASK, 8),
-            )
-            with daemon_context, _serve_logs(args.skip_serve_logs):
-                triggerer_job_runner = TriggererJobRunner(
-                    job=Job(heartrate=triggerer_heartrate), capacity=args.capacity
-                )
-                run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute)
-    else:
-        signal.signal(signal.SIGINT, sigint_handler)
-        signal.signal(signal.SIGTERM, sigint_handler)
-        signal.signal(signal.SIGQUIT, sigquit_handler)
-        with _serve_logs(args.skip_serve_logs):
-            triggerer_job_runner = TriggererJobRunner(
-                job=Job(heartrate=triggerer_heartrate), capacity=args.capacity
-            )
-            run_job(job=triggerer_job_runner.job, execute_callable=triggerer_job_runner._execute)
+    run_command_with_daemon_option(
+        args=args,
+        process_name="triggerer",
+        callback=lambda: triggerer_run(args.skip_serve_logs, args.capacity, triggerer_heartrate),
+        should_setup_logging=True,
+    )
diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py
index 5ae601b428..4cb7939fd7 100644
--- a/airflow/cli/commands/webserver_command.py
+++ b/airflow/cli/commands/webserver_command.py
@@ -27,26 +27,21 @@ import time
 from contextlib import suppress
 from pathlib import Path
 from time import sleep
-from typing import TYPE_CHECKING, NoReturn
+from typing import NoReturn
 
-import daemon
 import psutil
-from daemon.pidfile import TimeoutPIDLockFile
 from lockfile.pidlockfile import read_pid_from_pidfile
 
 from airflow import settings
+from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowWebServerTimeout
 from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging
+from airflow.utils.cli import setup_locations
 from airflow.utils.hashlib_wrapper import md5
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.process_utils import check_if_pidfile_process_is_running
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 
-if TYPE_CHECKING:
-    import types
-
 log = logging.getLogger(__name__)
 
 
@@ -367,13 +362,6 @@ def webserver(args):
             ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None,
         )
     else:
-        pid_file, stdout, stderr, log_file = setup_locations(
-            "webserver", args.pid, args.stdout, args.stderr, args.log_file
-        )
-
-        # Check if webserver is already running if not, remove old pidfile
-        check_if_pidfile_process_is_running(pid_file=pid_file, process_name="webserver")
-
         print(
             textwrap.dedent(
                 f"""\
@@ -387,6 +375,7 @@ def webserver(args):
             )
         )
 
+        pid_file, _, _, _ = setup_locations("webserver", pid=args.pid)
         run_args = [
             sys.executable,
             "-m",
@@ -436,9 +425,7 @@ def webserver(args):
             # all writing to the database at the same time, we use the --preload option.
             run_args += ["--preload"]
 
-        gunicorn_master_proc: psutil.Process | subprocess.Popen
-
-        def kill_proc(signum: int, frame: types.FrameType | None) -> NoReturn:
+        def kill_proc(signum: int, gunicorn_master_proc: psutil.Process | subprocess.Popen) -> NoReturn:
             log.info("Received signal: %s. Closing gunicorn.", signum)
             gunicorn_master_proc.terminate()
             with suppress(TimeoutError):
@@ -451,14 +438,14 @@ def webserver(args):
                 gunicorn_master_proc.kill()
             sys.exit(0)
 
-        def monitor_gunicorn(gunicorn_master_pid: int) -> NoReturn:
+        def monitor_gunicorn(gunicorn_master_proc: psutil.Process | subprocess.Popen) -> NoReturn:
             # Register signal handlers
-            signal.signal(signal.SIGINT, kill_proc)
-            signal.signal(signal.SIGTERM, kill_proc)
+            signal.signal(signal.SIGINT, lambda signum, _: kill_proc(signum, gunicorn_master_proc))
+            signal.signal(signal.SIGTERM, lambda signum, _: kill_proc(signum, gunicorn_master_proc))
 
             # These run forever until SIG{INT, TERM, KILL, ...} signal is sent
             GunicornMonitor(
-                gunicorn_master_pid=gunicorn_master_pid,
+                gunicorn_master_pid=gunicorn_master_proc.pid,
                 num_workers_expected=num_workers,
                 master_timeout=conf.getint("webserver", "web_server_master_timeout"),
                 worker_refresh_interval=conf.getint("webserver", "worker_refresh_interval", fallback=30),
@@ -468,42 +455,36 @@ def webserver(args):
                 ),
             ).start()
 
+        def start_and_monitor_gunicorn(args):
+            if args.daemon:
+                subprocess.Popen(run_args, close_fds=True)
+
+                # Reading pid of gunicorn master as it will be different that
+                # the one of process spawned above.
+                gunicorn_master_proc_pid = None
+                while not gunicorn_master_proc_pid:
+                    sleep(0.1)
+                    gunicorn_master_proc_pid = read_pid_from_pidfile(pid_file)
+
+                # Run Gunicorn monitor
+                gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid)
+                monitor_gunicorn(gunicorn_master_proc)
+            else:
+                with subprocess.Popen(run_args, close_fds=True) as gunicorn_master_proc:
+                    monitor_gunicorn(gunicorn_master_proc)
+
         if args.daemon:
             # This makes possible errors get reported before daemonization
             os.environ["SKIP_DAGS_PARSING"] = "True"
-            app = create_app(None)
+            create_app(None)
             os.environ.pop("SKIP_DAGS_PARSING")
 
-            handle = setup_logging(log_file)
-
-            pid_path = Path(pid_file)
-            pidlock_path = pid_path.with_name(f"{pid_path.stem}-monitor{pid_path.suffix}")
-
-            with open(stdout, "a") as stdout, open(stderr, "a") as stderr:
-                stdout.truncate(0)
-                stderr.truncate(0)
-
-                ctx = daemon.DaemonContext(
-                    pidfile=TimeoutPIDLockFile(pidlock_path, -1),
-                    files_preserve=[handle],
-                    stdout=stdout,
-                    stderr=stderr,
-                    umask=int(settings.DAEMON_UMASK, 8),
-                )
-                with ctx:
-                    subprocess.Popen(run_args, close_fds=True)
-
-                    # Reading pid of gunicorn master as it will be different that
-                    # the one of process spawned above.
-                    gunicorn_master_proc_pid = None
-                    while not gunicorn_master_proc_pid:
-                        sleep(0.1)
-                        gunicorn_master_proc_pid = read_pid_from_pidfile(pid_file)
-
-                    # Run Gunicorn monitor
-                    gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid)
-                    monitor_gunicorn(gunicorn_master_proc.pid)
-
-        else:
-            with subprocess.Popen(run_args, close_fds=True) as gunicorn_master_proc:
-                monitor_gunicorn(gunicorn_master_proc.pid)
+        pid_file_path = Path(pid_file)
+        monitor_pid_file = str(pid_file_path.with_name(f"{pid_file_path.stem}-monitor{pid_file_path.suffix}"))
+        run_command_with_daemon_option(
+            args=args,
+            process_name="webserver",
+            callback=lambda: start_and_monitor_gunicorn(args),
+            should_setup_logging=True,
+            pid_file=monitor_pid_file,
+        )
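
Both the webserver and internal-api branches derive the monitor's pid file from the gunicorn pid file with the same pathlib pattern. Roughly, assuming a default-style location (the concrete path below is only an example):

    from pathlib import Path

    pid_file = "/home/airflow/airflow-webserver.pid"  # illustrative location
    p = Path(pid_file)
    monitor_pid_file = str(p.with_name(f"{p.stem}-monitor{p.suffix}"))
    # monitor_pid_file == "/home/airflow/airflow-webserver-monitor.pid"
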
diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py
index ae968f1171..02f26d7f23 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -266,9 +266,9 @@ class TestFlowerCommand:
             ]
         )
 
-    @mock.patch("airflow.cli.commands.celery_command.TimeoutPIDLockFile")
-    @mock.patch("airflow.cli.commands.celery_command.setup_locations")
-    @mock.patch("airflow.cli.commands.celery_command.daemon")
+    @mock.patch("airflow.cli.commands.daemon_utils.TimeoutPIDLockFile")
+    @mock.patch("airflow.cli.commands.daemon_utils.setup_locations")
+    @mock.patch("airflow.cli.commands.daemon_utils.daemon")
     @mock.patch("airflow.providers.celery.executors.celery_executor.app")
     def test_run_command_daemon(self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file):
         mock_setup_locations.return_value = (
@@ -305,7 +305,7 @@ class TestFlowerCommand:
             ]
         )
         mock_open = mock.mock_open()
-        with mock.patch("airflow.cli.commands.celery_command.open", mock_open):
+        with mock.patch("airflow.cli.commands.daemon_utils.open", mock_open):
             celery_command.flower(args)
 
         mock_celery_app.start.assert_called_once_with(
@@ -320,11 +320,12 @@ class TestFlowerCommand:
                 "--conf=flower_config",
             ]
         )
-        assert mock_daemon.mock_calls == [
+        assert mock_daemon.mock_calls[:3] == [
             mock.call.DaemonContext(
                 pidfile=mock_pid_file.return_value,
-                stderr=mock_open.return_value,
+                files_preserve=None,
                 stdout=mock_open.return_value,
+                stderr=mock_open.return_value,
                 umask=0o077,
             ),
             mock.call.DaemonContext().__enter__(),
@@ -333,11 +334,10 @@ class TestFlowerCommand:
 
         assert mock_setup_locations.mock_calls == [
             mock.call(
-                log="/tmp/flower.log",
-                pid="/tmp/flower.pid",
                 process="flower",
-                stderr="/tmp/flower-stderr.log",
                 stdout="/tmp/flower-stdout.log",
+                stderr="/tmp/flower-stderr.log",
+                log="/tmp/flower.log",
             )
         ]
         mock_pid_file.assert_has_calls([mock.call(mock_setup_locations.return_value[0], -1)])
diff --git a/tests/cli/commands/test_kerberos_command.py b/tests/cli/commands/test_kerberos_command.py
index 41dce045fa..14eb1676bd 100644
--- a/tests/cli/commands/test_kerberos_command.py
+++ b/tests/cli/commands/test_kerberos_command.py
@@ -36,9 +36,9 @@ class TestKerberosCommand:
         kerberos_command.kerberos(args)
         mock_krb.run.assert_called_once_with(keytab="/tmp/airflow.keytab", principal="PRINCIPAL")
 
-    @mock.patch("airflow.cli.commands.kerberos_command.TimeoutPIDLockFile")
-    @mock.patch("airflow.cli.commands.kerberos_command.setup_locations")
-    @mock.patch("airflow.cli.commands.kerberos_command.daemon")
+    @mock.patch("airflow.cli.commands.daemon_utils.TimeoutPIDLockFile")
+    @mock.patch("airflow.cli.commands.daemon_utils.setup_locations")
+    @mock.patch("airflow.cli.commands.daemon_utils.daemon")
     @mock.patch("airflow.cli.commands.kerberos_command.krb")
     @conf_vars({("core", "executor"): "CeleryExecutor"})
     def test_run_command_daemon(self, mock_krb, mock_daemon, mock_setup_locations, mock_pid_file):
@@ -66,13 +66,14 @@ class TestKerberosCommand:
             ]
         )
         mock_open = mock.mock_open()
-        with mock.patch("airflow.cli.commands.kerberos_command.open", mock_open):
+        with mock.patch("airflow.cli.commands.daemon_utils.open", mock_open):
             kerberos_command.kerberos(args)
 
         mock_krb.run.assert_called_once_with(keytab="/tmp/airflow.keytab", principal="PRINCIPAL")
-        assert mock_daemon.mock_calls == [
+        assert mock_daemon.mock_calls[:3] == [
             mock.call.DaemonContext(
                 pidfile=mock_pid_file.return_value,
+                files_preserve=None,
                 stderr=mock_open.return_value,
                 stdout=mock_open.return_value,
                 umask=0o077,
@@ -81,18 +82,14 @@ class TestKerberosCommand:
             mock.call.DaemonContext().__exit__(None, None, None),
         ]
 
-        mock_setup_locations.assert_has_calls(
-            [
-                mock.call(
-                    "kerberos",
-                    "/tmp/kerberos.pid",
-                    "/tmp/kerberos-stdout.log",
-                    "/tmp/kerberos-stderr.log",
-                    "/tmp/kerberos.log",
-                )
-            ]
+        assert mock_setup_locations.mock_calls[0] == mock.call(
+            process="kerberos",
+            stdout="/tmp/kerberos-stdout.log",
+            stderr="/tmp/kerberos-stderr.log",
+            log="/tmp/kerberos.log",
         )
-        mock_pid_file.assert_has_calls([mock.call(mock_setup_locations.return_value[0], -1)])
+
+        mock_pid_file.mock_calls[0] = mock.call(mock_setup_locations.return_value[0], -1)
         assert mock_open.mock_calls == [
             mock.call(mock_setup_locations.return_value[1], "a"),
             mock.call().__enter__(),

