This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 05a67efe32 Add an option to load the dags from db for command tasks run (#32038)
05a67efe32 is described below

commit 05a67efe32af248ca191ea59815b3b202f893f46
Author: Hussein Awala <huss...@awala.fr>
AuthorDate: Sat Jun 24 00:31:05 2023 +0200

    Add an option to load the dags from db for command tasks run (#32038)

    Signed-off-by: Hussein Awala <huss...@awala.fr>
---
 airflow/cli/cli_config.py               |  2 ++
 airflow/cli/commands/task_command.py    |  2 +-
 airflow/utils/cli.py                    | 21 ++++++++++++++-------
 tests/cli/commands/test_task_command.py | 36 ++++++++++++++++++++++++++++++++++++
 4 files changed, 53 insertions(+), 8 deletions(-)

diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index 0c69571fea..587934769e 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -604,6 +604,7 @@ ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entir
 ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS)
 ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg")
 ARG_MAP_INDEX = Arg(("--map-index",), type=int, default=-1, help="Mapped task index")
+ARG_READ_FROM_DB = Arg(("--read-from-db",), help="Read dag from DB instead of dag file", action="store_true")


 # database
@@ -1453,6 +1454,7 @@ TASKS_COMMANDS = (
             ARG_SHUT_DOWN_LOGGING,
             ARG_MAP_INDEX,
             ARG_VERBOSE,
+            ARG_READ_FROM_DB,
         ),
     ),
     ActionCommand(
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index aab8bb10dc..560764536b 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -398,7 +398,7 @@ def task_run(args, dag: DAG | None = None) -> TaskReturnCode | None:
         print(f"Loading pickle id: {args.pickle}")
         _dag = get_dag_by_pickle(args.pickle)
     elif not dag:
-        _dag = get_dag(args.subdir, args.dag_id)
+        _dag = get_dag(args.subdir, args.dag_id, args.read_from_db)
     else:
         _dag = dag
     task = _dag.get_task(task_id=args.task_id)
diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index d9e53ac072..56ac166b54 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -215,27 +215,34 @@ def _search_for_dag_file(val: str | None) -> str | None:
     return None


-def get_dag(subdir: str | None, dag_id: str) -> DAG:
+def get_dag(subdir: str | None, dag_id: str, from_db: bool = False) -> DAG:
     """
     Returns DAG of a given dag_id.

-    First it we'll try to use the given subdir.  If that doesn't work, we'll try to
+    First we'll try to use the given subdir.  If that doesn't work, we'll try to
     find the correct path (assuming it's a file) and failing that, use the configured
     dags folder.
     """
     from airflow.models import DagBag

-    first_path = process_subdir(subdir)
-    dagbag = DagBag(first_path)
-    if dag_id not in dagbag.dags:
+    if from_db:
+        dagbag = DagBag(read_dags_from_db=True)
+    else:
+        first_path = process_subdir(subdir)
+        dagbag = DagBag(first_path)
+    dag = dagbag.get_dag(dag_id)
+    if not dag:
+        if from_db:
+            raise AirflowException(f"Dag {dag_id!r} could not be found in DagBag read from database.")
         fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER
         logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path)
         dagbag = DagBag(dag_folder=fallback_path)
-        if dag_id not in dagbag.dags:
+        dag = dagbag.get_dag(dag_id)
+        if not dag:
             raise AirflowException(
                 f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
             )
-    return dagbag.dags[dag_id]
+    return dag


 def get_dags(subdir: str | None, dag_id: str, use_regex: bool = False):
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 376faeda41..646d76f47c 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -276,6 +276,42 @@ class TestCliTasks:
             external_executor_id=None,
         )

+    @pytest.mark.parametrize(
+        "from_db",
+        [True, False],
+    )
+    @mock.patch("airflow.cli.commands.task_command.LocalTaskJobRunner")
+    def test_run_with_read_from_db(self, mock_local_job_runner, caplog, from_db):
+        """
+        Test that we can run with read from db
+        """
+        task0_id = self.dag.task_ids[0]
+        args0 = [
+            "tasks",
+            "run",
+            "--ignore-all-dependencies",
+            "--local",
+            self.dag_id,
+            task0_id,
+            self.run_id,
+        ] + (["--read-from-db"] if from_db else [])
+        mock_local_job_runner.return_value.job_type = "LocalTaskJob"
+        task_command.task_run(self.parser.parse_args(args0))
+        mock_local_job_runner.assert_called_once_with(
+            job=mock.ANY,
+            task_instance=mock.ANY,
+            mark_success=False,
+            ignore_all_deps=True,
+            ignore_depends_on_past=False,
+            wait_for_past_depends_before_skipping=False,
+            ignore_task_deps=False,
+            ignore_ti_state=False,
+            pickle_id=None,
+            pool=None,
+            external_executor_id=None,
+        )
+        assert ("Filling up the DagBag from" in caplog.text) != from_db
+
     @mock.patch("airflow.cli.commands.task_command.LocalTaskJobRunner")
     def test_run_raises_when_theres_no_dagrun(self, mock_local_job):
         """
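
A minimal usage sketch of the code path this commit adds (not taken from the
commit itself; "example_dag_id" is a placeholder and assumes the DAG has
already been serialized to the metadata database by the DAG processor):

    # Equivalent CLI invocation, with placeholder positional arguments:
    #   airflow tasks run --local <dag_id> <task_id> <run_id> --read-from-db
    from airflow.utils.cli import get_dag

    # With from_db=True, get_dag builds DagBag(read_dags_from_db=True), so no
    # DAG files are parsed from the subdir or the dags folder. If the DAG is
    # not in the database, it raises AirflowException immediately instead of
    # falling back to a filesystem search.
    dag = get_dag(subdir=None, dag_id="example_dag_id", from_db=True)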