This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new f1301daef0 Improve stability of remove_task_decorator function (#38649) f1301daef0 is described below commit f1301daef027a750f4060e0f26d53151af99d5f7 Author: rom sharon <33751805+romsharo...@users.noreply.github.com> AuthorDate: Mon Apr 1 20:10:31 2024 +0300 Improve stability of remove_task_decorator function (#38649) * Improve stability of remove_task_decorator function * fix statics * test * remove test --------- Co-authored-by: Sam Wheating <samwheat...@gmail.com> --- airflow/utils/decorators.py | 10 ++++++++-- tests/utils/test_python_virtualenv.py | 5 +++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index 21b6ff3412..77a5eddaf0 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -62,12 +62,18 @@ def remove_task_decorator(python_source: str, task_decorator_name: str) -> str: :param python_source: The python source code :param task_decorator_name: the decorator name + + TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse """ def _remove_task_decorator(py_source, decorator_name): - if decorator_name not in py_source: + # if no line starts with @decorator_name, we can early exit + for line in py_source.split("\n"): + if line.startswith(decorator_name): + break + else: return python_source - split = python_source.split(decorator_name) + split = python_source.split(decorator_name, 1) before_decorator, after_decorator = split[0], split[1] if after_decorator[0] == "(": after_decorator = _balance_parens(after_decorator) diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py index 8f067fd3ae..38cda4854b 100644 --- a/tests/utils/test_python_virtualenv.py +++ b/tests/utils/test_python_virtualenv.py @@ -125,6 +125,11 @@ class TestPrepareVirtualenv: res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") assert res == "def f():\nimport funcsigs" + def test_remove_decorator_including_comment(self): + py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\nimport funcsigs" + res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") + assert res == "def f():\n# @task.virtualenv\nimport funcsigs" + def test_remove_decorator_nested(self): py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")