This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f1a53801a4 Support `@task.bash` with Task SDK (#48060)
8f1a53801a4 is described below

commit 8f1a53801a4da94fb81f65c11dcccf74601e1859
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Mar 22 00:28:55 2025 +0530

    Support `@task.bash` with Task SDK (#48060)
---
 .pre-commit-config.yaml                            |  2 +-
 airflow-core/src/airflow/decorators/bash.py        |  8 +++-
 airflow-core/tests/unit/decorators/test_bash.py    | 52 ++++++++++++----------
 .../airflow/providers/standard/operators/bash.py   | 52 +++-------------------
 .../tests/unit/standard/operators/test_bash.py     |  2 -
 .../check_base_operator_partial_arguments.py       |  1 +
 .../src/airflow/sdk/definitions/_internal/types.py | 12 +++++
 .../src/airflow/sdk/definitions/baseoperator.py    |  5 +++
 .../src/airflow/sdk/execution_time/task_runner.py  | 10 +++++
 .../task_sdk/execution_time/test_task_runner.py    | 21 +++++++++
 10 files changed, 90 insertions(+), 75 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 84dbf743439..36cec5f5a4a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1356,7 +1356,7 @@ repos:
         name: Check templated fields mapped in operators/sensors
         language: python
         entry: ./scripts/ci/pre_commit/check_template_fields.py
-        files: ^(providers/.*/)?airflow/.*/(sensors|operators)/.*\.py$
+        files: ^(providers/.*/)?airflow-core/.*/(sensors|operators)/.*\.py$
         additional_dependencies: [ 'rich>=12.4.4' ]
         require_serial: true
       - id: update-migration-references
diff --git a/airflow-core/src/airflow/decorators/bash.py b/airflow-core/src/airflow/decorators/bash.py
index 996ac5ffe05..a82575ce3ac 100644
--- a/airflow-core/src/airflow/decorators/bash.py
+++ b/airflow-core/src/airflow/decorators/bash.py
@@ -23,9 +23,9 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar
 
 from airflow.decorators.base import DecoratedOperator, TaskDecorator, 
task_decorator_factory
 from airflow.providers.standard.operators.bash import BashOperator
+from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
 from airflow.utils.context import context_merge
 from airflow.utils.operator_helpers import determine_kwargs
-from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
     from airflow.sdk.definitions.context import Context
@@ -49,6 +49,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
     }
 
     custom_operator_name: str = "@task.bash"
+    overwrite_rtif_after_execution: bool = True
 
     def __init__(
         self,
@@ -69,7 +70,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
             python_callable=python_callable,
             op_args=op_args,
             op_kwargs=op_kwargs,
-            bash_command=NOTSET,
+            bash_command=SET_DURING_EXECUTION,
             multiple_outputs=False,
             **kwargs,
         )
@@ -83,6 +84,9 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
         if not isinstance(self.bash_command, str) or self.bash_command.strip() 
== "":
             raise TypeError("The returned value from the TaskFlow callable 
must be a non-empty string.")
 
+        self._is_inline_cmd = 
self._is_inline_command(bash_command=self.bash_command)
+        context["ti"].render_templates()  # type: ignore[attr-defined]
+
         return super().execute(context)
 
 
diff --git a/airflow-core/tests/unit/decorators/test_bash.py b/airflow-core/tests/unit/decorators/test_bash.py
index 3dfbf2cc4e9..619326a5632 100644
--- a/airflow-core/tests/unit/decorators/test_bash.py
+++ b/airflow-core/tests/unit/decorators/test_bash.py
@@ -29,10 +29,11 @@ import pytest
 from airflow.decorators import task
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.models.renderedtifields import RenderedTaskInstanceFields
+from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
 from airflow.utils import timezone
-from airflow.utils.types import NOTSET
 
 from tests_common.test_utils.db import clear_db_dags, clear_db_runs, 
clear_rendered_ti_fields
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 
 if TYPE_CHECKING:
     from airflow.models import TaskInstance
@@ -69,7 +70,10 @@ class TestBashDecorator:
 
     @staticmethod
     def validate_bash_command_rtif(ti, expected_command):
-        assert 
RenderedTaskInstanceFields.get_templated_fields(ti)["bash_command"] == 
expected_command
+        if AIRFLOW_V_3_0_PLUS:
+            assert ti.task.overwrite_rtif_after_execution
+        else:
+            assert 
RenderedTaskInstanceFields.get_templated_fields(ti)["bash_command"] == 
expected_command
 
     def test_bash_decorator_init(self):
         """Test the initialization of the @task.bash decorator."""
@@ -81,13 +85,13 @@ class TestBashDecorator:
             bash_task = bash()
 
         assert bash_task.operator.task_id == "bash"
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
         assert bash_task.operator.env is None
         assert bash_task.operator.append_env is False
         assert bash_task.operator.output_encoding == "utf-8"
         assert bash_task.operator.skip_on_exit_code == [99]
         assert bash_task.operator.cwd is None
-        assert bash_task.operator._init_bash_command_not_set is True
+        assert bash_task.operator._is_inline_cmd is None
 
     @pytest.mark.parametrize(
         argnames=["command", "expected_command", "expected_return_val"],
@@ -108,13 +112,12 @@ class TestBashDecorator:
 
             bash_task = bash()
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         ti, return_val = self.execute_task(bash_task)
 
         assert bash_task.operator.bash_command == expected_command
         assert return_val == expected_return_val
-
         self.validate_bash_command_rtif(ti, expected_command)
 
     def test_op_args_kwargs(self):
@@ -127,7 +130,7 @@ class TestBashDecorator:
 
             bash_task = bash("world", other_id="2")
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         ti, return_val = self.execute_task(bash_task)
 
@@ -152,7 +155,7 @@ class TestBashDecorator:
 
             bash_task = bash("foo")
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         ti, return_val = self.execute_task(bash_task)
 
@@ -178,7 +181,7 @@ class TestBashDecorator:
 
             bash_task = bash()
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         with mock.patch.dict("os.environ", {"AIRFLOW_HOME": 
"path/to/airflow/home"}):
             ti, return_val = self.execute_task(bash_task)
@@ -207,7 +210,7 @@ class TestBashDecorator:
 
             bash_task = bash(exit_code)
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         with expected:
             ti, return_val = self.execute_task(bash_task)
@@ -251,7 +254,7 @@ class TestBashDecorator:
 
             bash_task = bash(exit_code)
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         with expected:
             ti, return_val = self.execute_task(bash_task)
@@ -297,7 +300,7 @@ class TestBashDecorator:
             with mock.patch.dict("os.environ", {"AIRFLOW_HOME": 
"path/to/airflow/home"}):
                 bash_task = bash(f"{cmd_file} ")
 
-                assert bash_task.operator.bash_command == NOTSET
+                assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
                 ti, return_val = self.execute_task(bash_task)
 
@@ -319,7 +322,7 @@ class TestBashDecorator:
 
             bash_task = bash()
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         ti, return_val = self.execute_task(bash_task)
 
@@ -339,7 +342,7 @@ class TestBashDecorator:
 
             bash_task = bash()
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         dr = self.dag_maker.create_dagrun()
         ti = dr.task_instances[0]
@@ -360,7 +363,7 @@ class TestBashDecorator:
 
             bash_task = bash()
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         dr = self.dag_maker.create_dagrun()
         ti = dr.task_instances[0]
@@ -378,7 +381,7 @@ class TestBashDecorator:
 
             bash_task = bash()
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         dr = self.dag_maker.create_dagrun()
         ti = dr.task_instances[0]
@@ -401,7 +404,7 @@ class TestBashDecorator:
             ):
                 bash_task = bash()
 
-                assert bash_task.operator.bash_command == NOTSET
+                assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
                 ti, _ = self.execute_task(bash_task)
 
@@ -409,12 +412,13 @@ class TestBashDecorator:
         self.validate_bash_command_rtif(ti, "echo")
 
     @pytest.mark.parametrize(
-        "multiple_outputs", [False, pytest.param(None, id="none"), 
pytest.param(NOTSET, id="not-set")]
+        "multiple_outputs",
+        [False, pytest.param(None, id="none"), 
pytest.param(SET_DURING_EXECUTION, id="not-set")],
     )
     def test_multiple_outputs(self, multiple_outputs):
         """Verify setting `multiple_outputs` for a @task.bash-decorated 
function is ignored."""
         decorator_kwargs = {}
-        if multiple_outputs is not NOTSET:
+        if multiple_outputs is not SET_DURING_EXECUTION:
             decorator_kwargs["multiple_outputs"] = multiple_outputs
 
         with self.dag:
@@ -428,7 +432,7 @@ class TestBashDecorator:
 
                 bash_task = bash()
 
-                assert bash_task.operator.bash_command == NOTSET
+                assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
                 ti, _ = self.execute_task(bash_task)
 
@@ -440,7 +444,9 @@ class TestBashDecorator:
         argvalues=[
             pytest.param(None, pytest.raises(TypeError), 
id="return_none_typeerror"),
             pytest.param(1, pytest.raises(TypeError), 
id="return_int_typeerror"),
-            pytest.param(NOTSET, pytest.raises(TypeError), 
id="return_notset_typeerror"),
+            pytest.param(
+                SET_DURING_EXECUTION, pytest.raises(TypeError), 
id="return_SET_DURING_EXECUTION_typeerror"
+            ),
             pytest.param(True, pytest.raises(TypeError), 
id="return_boolean_typeerror"),
             pytest.param("", pytest.raises(TypeError), 
id="return_empty_string_typerror"),
             pytest.param("  ", pytest.raises(TypeError), 
id="return_spaces_string_typerror"),
@@ -458,7 +464,7 @@ class TestBashDecorator:
 
             bash_task = bash()
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         with expected:
             ti, _ = self.execute_task(bash_task)
@@ -475,7 +481,7 @@ class TestBashDecorator:
 
             bash_task = bash()
 
-        assert bash_task.operator.bash_command == NOTSET
+        assert bash_task.operator.bash_command == SET_DURING_EXECUTION
 
         dr = self.dag_maker.create_dagrun()
         ti = dr.task_instances[0]
diff --git a/providers/standard/src/airflow/providers/standard/operators/bash.py b/providers/standard/src/airflow/providers/standard/operators/bash.py
index 02d53737588..02a3c03afbd 100644
--- a/providers/standard/src/airflow/providers/standard/operators/bash.py
+++ b/providers/standard/src/airflow/providers/standard/operators/bash.py
@@ -28,8 +28,6 @@ from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.models.baseoperator import BaseOperator
 from airflow.providers.standard.hooks.subprocess import SubprocessHook, 
SubprocessResult, working_directory
 from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.types import ArgNotSet
 
 if AIRFLOW_V_3_0_PLUS:
     from airflow.sdk.execution_time.context import context_to_airflow_vars
@@ -37,7 +35,7 @@ else:
     from airflow.utils.operator_helpers import context_to_airflow_vars  # 
type: ignore[no-redef, attr-defined]
 
 if TYPE_CHECKING:
-    from sqlalchemy.orm import Session as SASession
+    from airflow.utils.types import ArgNotSet
 
     try:
         from airflow.sdk.definitions.context import Context
@@ -187,43 +185,15 @@ class BashOperator(BaseOperator):
         self.cwd = cwd
         self.append_env = append_env
         self.output_processor = output_processor
-
-        # When using the @task.bash decorator, the Bash command is not known 
until the underlying Python
-        # callable is executed and therefore set to NOTSET initially. This 
flag is useful during execution to
-        # determine whether the bash_command value needs to re-rendered.
-        self._init_bash_command_not_set = isinstance(self.bash_command, 
ArgNotSet)
-
-        # Keep a copy of the original bash_command, without the Jinja template 
rendered.
-        # This is later used to determine if the bash_command is a script or 
an inline string command.
-        # We do this later, because the bash_command is not available in 
__init__ when using @task.bash.
-        self._unrendered_bash_command: str | ArgNotSet = bash_command
+        self._is_inline_cmd = None
+        if isinstance(bash_command, str):
+            self._is_inline_cmd = 
self._is_inline_command(bash_command=bash_command)
 
     @cached_property
     def subprocess_hook(self):
         """Returns hook for running the bash command."""
         return SubprocessHook()
 
-    # TODO: This should be replaced with Task SDK API call
-    @staticmethod
-    @provide_session
-    def refresh_bash_command(ti, session: SASession = NEW_SESSION) -> None:
-        """
-        Rewrite the underlying rendered bash_command value for a task instance 
in the metadatabase.
-
-        TaskInstance.get_rendered_template_fields() cannot be used because 
this will retrieve the
-        RenderedTaskInstanceFields from the metadatabase which doesn't have 
the runtime-evaluated bash_command
-        value.
-
-        :meta private:
-        """
-        from airflow.models.renderedtifields import RenderedTaskInstanceFields
-
-        """Update rendered task instance fields for cases where runtime 
evaluated, not templated."""
-
-        rtif = RenderedTaskInstanceFields(ti)
-        RenderedTaskInstanceFields.write(rtif, session=session)
-        RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, 
session=session)
-
     def get_env(self, context) -> dict:
         """Build the set of environment variables to be exposed for the bash 
command."""
         system_env = os.environ.copy()
@@ -252,19 +222,7 @@ class BashOperator(BaseOperator):
                 raise AirflowException(f"The cwd {self.cwd} must be a 
directory")
         env = self.get_env(context)
 
-        # Because the bash_command value is evaluated at runtime using the 
@task.bash decorator, the
-        # RenderedTaskInstanceField data needs to be rewritten and the 
bash_command value re-rendered -- the
-        # latter because the returned command from the decorated callable 
could contain a Jinja expression.
-        # Both will ensure the correct Bash command is executed and that the 
Rendered Template view in the UI
-        # displays the executed command (otherwise it will display as an 
ArgNotSet type).
-        if self._init_bash_command_not_set:
-            is_inline_command = self._is_inline_command(bash_command=cast(str, 
self.bash_command))
-            ti = context["ti"]
-            self.refresh_bash_command(ti)
-        else:
-            is_inline_command = self._is_inline_command(bash_command=cast(str, 
self._unrendered_bash_command))
-
-        if is_inline_command:
+        if self._is_inline_cmd:
             result = self._run_inline_command(bash_path=bash_path, env=env)
         else:
             result = self._run_rendered_script_file(bash_path=bash_path, 
env=env)
diff --git a/providers/standard/tests/unit/standard/operators/test_bash.py b/providers/standard/tests/unit/standard/operators/test_bash.py
index 59a0c8bbf23..fe33689d00e 100644
--- a/providers/standard/tests/unit/standard/operators/test_bash.py
+++ b/providers/standard/tests/unit/standard/operators/test_bash.py
@@ -60,8 +60,6 @@ class TestBashOperator:
         assert op.output_encoding == "utf-8"
         assert op.skip_on_exit_code == [99]
         assert op.cwd is None
-        assert op._init_bash_command_not_set is False
-        assert op._unrendered_bash_command == "echo"
 
     @pytest.mark.db_test
     @pytest.mark.parametrize(
diff --git a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py
index e17c304a7ad..970d5d50aa7 100755
--- a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py
+++ b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py
@@ -43,6 +43,7 @@ IGNORED = {
     "post_execute",
     "pre_execute",
     "multiple_outputs",
+    "overwrite_rtif_after_execution",
     # Doesn't matter, not used anywhere.
     "default_args",
     # Deprecated and is aliased to max_active_tis_per_dag.
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/types.py b/task-sdk/src/airflow/sdk/definitions/_internal/types.py
index 0e3a39cde20..b8bd9fbfc4b 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/types.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/types.py
@@ -49,6 +49,18 @@ NOTSET = ArgNotSet()
 """Sentinel value for argument default. See ``ArgNotSet``."""
 
 
+class SetDuringExecution(ArgNotSet):
+    """Sentinel type for annotations, useful when a value is dynamic and set 
during Execution but not parsing."""
+
+    @staticmethod
+    def serialize() -> str:
+        return "DYNAMIC (set during execution)"
+
+
+SET_DURING_EXECUTION = SetDuringExecution()
+"""Sentinel value for argument default. See ``SetDuringExecution``."""
+
+
 if TYPE_CHECKING:
     import logging
 
diff --git a/task-sdk/src/airflow/sdk/definitions/baseoperator.py b/task-sdk/src/airflow/sdk/definitions/baseoperator.py
index fd419cd6b57..4aa19108b68 100644
--- a/task-sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -899,6 +899,11 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     # Defines if the operator supports lineage without manual definitions
     supports_lineage: bool = False
 
+    # If True, the Rendered Template fields will be overwritten in DB after 
execution
+    # This is useful for Taskflow decorators that modify the template fields 
during execution like
+    # @task.bash decorator.
+    overwrite_rtif_after_execution: bool = False
+
     # If True then the class constructor was called
     __instantiated: bool = False
     # List of args as passed to `init()`, after apply_defaults() has been 
updated. Used to "recreate" the task
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 5f23902df77..31d44f1116f 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -899,6 +899,16 @@ def finalize(
         log.debug("Setting xcom for operator extra link", link=link, 
xcom_key=xcom_key)
         _xcom_push(ti, key=xcom_key, value=link)
 
+    if getattr(ti.task, "overwrite_rtif_after_execution", False):
+        log.debug("Overwriting Rendered template fields.")
+        if ti.task.template_fields:
+            SUPERVISOR_COMMS.send_request(
+                log=log,
+                msg=SetRenderedFields(
+                    rendered_fields={field: getattr(ti.task, field) for field 
in ti.task.template_fields}
+                ),
+            )
+
     log.debug("Running finalizers", ti=ti)
     if state == TerminalTIState.SUCCESS:
         get_listener_manager().hook.on_task_instance_success(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 96d7e2fb977..8f59eddbd11 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -1396,6 +1396,27 @@ class TestRuntimeTaskInstance:
             log=mock.ANY,
         )
 
+    def test_overwrite_rtif_after_execution_sets_rtif(self, create_runtime_ti, 
mock_supervisor_comms):
+        """Test that the RTIF is overwritten after execution for certain 
operators."""
+
+        class CustomOperator(BaseOperator):
+            overwrite_rtif_after_execution = True
+            template_fields = ["bash_command"]
+
+            def __init__(self, bash_command, *args, **kwargs):
+                self.bash_command = bash_command
+                super().__init__(*args, **kwargs)
+
+        task = CustomOperator(task_id="hello", bash_command="echo 'hi'")
+        runtime_ti = create_runtime_ti(task=task)
+
+        finalize(runtime_ti, log=mock.MagicMock(), 
state=TerminalTIState.SUCCESS)
+
+        mock_supervisor_comms.send_request.assert_called_with(
+            msg=SetRenderedFields(rendered_fields={"bash_command": "echo 
'hi'"}),
+            log=mock.ANY,
+        )
+
 
 class TestXComAfterTaskExecution:
     @pytest.mark.parametrize(

Reply via email to