This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f1301daef0 Improve stability of remove_task_decorator function (#38649)
f1301daef0 is described below

commit f1301daef027a750f4060e0f26d53151af99d5f7
Author: rom sharon <33751805+romsharo...@users.noreply.github.com>
AuthorDate: Mon Apr 1 20:10:31 2024 +0300

    Improve stability of remove_task_decorator function (#38649)
    
    * Improve stability of remove_task_decorator function
    
    * fix static checks
    
    * test
    
    * remove test
    
    ---------
    
    Co-authored-by: Sam Wheating <samwheat...@gmail.com>
---
 airflow/utils/decorators.py           | 10 ++++++++--
 tests/utils/test_python_virtualenv.py |  5 +++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py
index 21b6ff3412..77a5eddaf0 100644
--- a/airflow/utils/decorators.py
+++ b/airflow/utils/decorators.py
@@ -62,12 +62,18 @@ def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
 
     :param python_source: The python source code
     :param task_decorator_name: the decorator name
+
+    TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse
     """
 
     def _remove_task_decorator(py_source, decorator_name):
-        if decorator_name not in py_source:
+        # if no line starts with @decorator_name, we can early exit
+        for line in py_source.split("\n"):
+            if line.startswith(decorator_name):
+                break
+        else:
             return python_source
-        split = python_source.split(decorator_name)
+        split = python_source.split(decorator_name, 1)
         before_decorator, after_decorator = split[0], split[1]
         if after_decorator[0] == "(":
             after_decorator = _balance_parens(after_decorator)
diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py
index 8f067fd3ae..38cda4854b 100644
--- a/tests/utils/test_python_virtualenv.py
+++ b/tests/utils/test_python_virtualenv.py
@@ -125,6 +125,11 @@ class TestPrepareVirtualenv:
         res = remove_task_decorator(python_source=py_source, 
task_decorator_name="@task.virtualenv")
         assert res == "def f():\nimport funcsigs"
 
+    def test_remove_decorator_including_comment(self):
+        py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\nimport 
funcsigs"
+        res = remove_task_decorator(python_source=py_source, 
task_decorator_name="@task.virtualenv")
+        assert res == "def f():\n# @task.virtualenv\nimport funcsigs"
+
     def test_remove_decorator_nested(self):
         py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\nimport funcsigs"
         res = remove_task_decorator(python_source=py_source, 
task_decorator_name="@task.virtualenv")

Reply via email to