This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch tasksdk-call-listeners
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 7f7dc9d2c315fe42a2c9135526b1acee9f7f682a
Author: Maciej Obuchowski <obuchowski.mac...@gmail.com>
AuthorDate: Wed Jan 8 17:15:48 2025 +0100

    remove redundant dagrun model, fix edge command test
    
    Signed-off-by: Maciej Obuchowski <obuchowski.mac...@gmail.com>
---
 airflow/api_fastapi/execution_api/app.py           |  3 +-
 .../api_fastapi/execution_api/datamodels/dagrun.py | 35 ----------
 .../execution_api/routes/task_instances.py         | 15 +++-
 providers/tests/edge/cli/test_edge_command.py      |  2 +
 .../tests/edge/executors/test_edge_executor.py     |  1 +
 .../tests/openlineage/extractors/test_manager.py   |  8 ++-
 .../tests/openlineage/plugins/test_listener.py     | 81 ++++++++++++----------
 task_sdk/tests/execution_time/conftest.py          |  7 +-
 task_sdk/tests/execution_time/test_supervisor.py   |  2 +-
 task_sdk/tests/execution_time/test_task_runner.py  | 13 +++-
 10 files changed, 86 insertions(+), 81 deletions(-)
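
For context when reviewing: after this change the Task SDK TaskInstance payload carries a start_date, and StartupDetails is built from dag_rel_path plus bundle_info rather than a bare file path, which is why the test fixtures in this patch gain those fields. The snippet below is a minimal sketch of that construction, not part of the patch; it assumes an Airflow 3 development checkout where the task_sdk datamodels are importable, and the example_dag/example_task/manual_run names, the example_dag.py path, and the uuid4() id are illustrative stand-ins (the tests themselves use uuid7()).

    # Sketch only -- mirrors the updated test fixtures, assuming an Airflow 3
    # dev checkout with the Task SDK datamodels importable.
    from uuid import uuid4

    from airflow.sdk.api.datamodels._generated import BundleInfo, TaskInstance
    from airflow.sdk.execution_time.comms import StartupDetails
    from airflow.utils import timezone

    # The TaskInstance payload now carries start_date alongside the usual identifiers.
    ti = TaskInstance(
        id=uuid4(),               # the tests use uuid7(); any UUID works for illustration
        task_id="example_task",   # illustrative names, not taken from the patch
        dag_id="example_dag",
        run_id="manual_run",
        try_number=1,
        start_date=timezone.utcnow(),
    )

    # StartupDetails takes dag_rel_path and bundle_info instead of the old bare
    # file path. model_construct() skips validation so this sketch can omit
    # ti_context, which the real tests build via the make_ti_context fixture.
    what = StartupDetails.model_construct(
        ti=ti,
        dag_rel_path="example_dag.py",
        bundle_info=BundleInfo.model_construct(name="anything", version="any"),
        requests_fd=0,
    )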

diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py
index 702ee9ad6a4..61283dc2cf8 100644
--- a/airflow/api_fastapi/execution_api/app.py
+++ b/airflow/api_fastapi/execution_api/app.py
@@ -77,9 +77,8 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI:
 
 def get_extra_schemas() -> dict[str, dict]:
     """Get all the extra schemas that are not part of the main FastAPI app."""
-    from airflow.api_fastapi.execution_api.datamodels import dagrun, taskinstance
+    from airflow.api_fastapi.execution_api.datamodels import taskinstance
 
     return {
         "TaskInstance": taskinstance.TaskInstance.model_json_schema(),
-        "DagRun": dagrun.DagRun.model_json_schema(),
     }
diff --git a/airflow/api_fastapi/execution_api/datamodels/dagrun.py b/airflow/api_fastapi/execution_api/datamodels/dagrun.py
deleted file mode 100644
index f9f99c1c43b..00000000000
--- a/airflow/api_fastapi/execution_api/datamodels/dagrun.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# This model is not used in the API, but it is included in generated OpenAPI schema
-# for use in the client SDKs.
-from __future__ import annotations
-
-from airflow.api_fastapi.common.types import UtcDateTime
-from airflow.api_fastapi.core_api.base import BaseModel
-
-
-class DagRun(BaseModel):
-    """Schema for TaskInstance model with minimal required fields needed for OL for now."""
-
-    id: int
-    dag_id: str
-    run_id: str
-    logical_date: UtcDateTime
-    data_interval_start: UtcDateTime
-    data_interval_end: UtcDateTime
-    clear_number: int
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py
index e543e9ce46f..a0b29f29e79 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -75,12 +75,23 @@ def ti_run(
     ti_id_str = str(task_instance_id)
 
     old = (
-        select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, TI.next_method, TI.try_number, TI.max_tries)
+        select(
+            TI.state,
+            TI.dag_id,
+            TI.run_id,
+            TI.task_id,
+            TI.map_index,
+            TI.next_method,
+            TI.try_number,
+            TI.max_tries,
+        )
         .where(TI.id == ti_id_str)
         .with_for_update()
     )
     try:
-        (previous_state, dag_id, run_id, task_id, map_index, next_method, try_number, max_tries) = session.execute(old).one()
+        (previous_state, dag_id, run_id, task_id, map_index, next_method, try_number, max_tries) = (
+            session.execute(old).one()
+        )
     except NoResultFound:
         log.error("Task Instance %s not found", ti_id_str)
         raise HTTPException(
diff --git a/providers/tests/edge/cli/test_edge_command.py b/providers/tests/edge/cli/test_edge_command.py
index b1b719444ba..4bfed9d4947 100644
--- a/providers/tests/edge/cli/test_edge_command.py
+++ b/providers/tests/edge/cli/test_edge_command.py
@@ -51,6 +51,8 @@ MOCK_COMMAND = (
             "pool_slots": 1,
             "queue": "default",
             "priority_weight": 1,
+            "start_date": "2023-01-01T00:00:00+00:00",
+            "map_index": -1,
         },
         "dag_rel_path": "dummy.py",
         "log_path": "dummy.log",
diff --git a/providers/tests/edge/executors/test_edge_executor.py b/providers/tests/edge/executors/test_edge_executor.py
index 3a5e6b18d69..ee0b67a74cd 100644
--- a/providers/tests/edge/executors/test_edge_executor.py
+++ b/providers/tests/edge/executors/test_edge_executor.py
@@ -300,6 +300,7 @@ class TestEdgeExecutor:
                 pool_slots=1,
                 queue="default",
                 priority_weight=1,
+                start_date=timezone.utcnow(),
             ),
             dag_rel_path="dummy.py",
             log_path="dummy.log",
diff --git a/providers/tests/openlineage/extractors/test_manager.py b/providers/tests/openlineage/extractors/test_manager.py
index c3a56f16872..f98bce23d37 100644
--- a/providers/tests/openlineage/extractors/test_manager.py
+++ b/providers/tests/openlineage/extractors/test_manager.py
@@ -46,6 +46,7 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_
 
 if TYPE_CHECKING:
     from datetime import datetime
+
     try:
         from airflow.sdk.definitions.context import Context
     except ImportError:
@@ -72,7 +73,7 @@ if AIRFLOW_V_2_10_PLUS:
 
 
 if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk.api.datamodels._generated import TaskInstance as SDKTaskInstance
+    from airflow.sdk.api.datamodels._generated import BundleInfo, TaskInstance as SDKTaskInstance
     from airflow.sdk.execution_time import task_runner
     from airflow.sdk.execution_time.comms import StartupDetails
     from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse
@@ -426,6 +427,7 @@ def make_ti_context() -> MakeTIContextCallable:
                 run_type=run_type,  # type: ignore
             ),
             task_reschedule_count=task_reschedule_count,
+            max_tries=1,
         )
 
     return _make_context
@@ -446,6 +448,7 @@ def test_extractor_manager_gets_data_from_pythonoperator_tasksdk(
                 out.write("test")
 
     task = PythonOperator(task_id="test_task_extractor_pythonoperator", python_callable=use_read)
+    FAKE_BUNDLE = BundleInfo.model_construct(name="anything", version="any")
 
     what = StartupDetails(
         ti=SDKTaskInstance(
@@ -456,7 +459,8 @@ def test_extractor_manager_gets_data_from_pythonoperator_tasksdk(
             try_number=1,
             start_date=timezone.utcnow(),
         ),
-        file="",
+        dag_rel_path="",
+        bundle_info=FAKE_BUNDLE,
         requests_fd=0,
         ti_context=make_ti_context(),
     )
diff --git a/providers/tests/openlineage/plugins/test_listener.py b/providers/tests/openlineage/plugins/test_listener.py
index 10896286a7d..18d5c611c50 100644
--- a/providers/tests/openlineage/plugins/test_listener.py
+++ b/providers/tests/openlineage/plugins/test_listener.py
@@ -171,40 +171,40 @@ class TestOpenLineageListenerAirflow2:
 
         :return: TaskInstance: The created TaskInstance object.
 
-    This function creates a DAG and a PythonOperator task with the provided
-    python_callable. It generates a unique run ID and creates a DAG run. This
-    setup is useful for testing different scenarios in Airflow tasks.
+        This function creates a DAG and a PythonOperator task with the provided
+        python_callable. It generates a unique run ID and creates a DAG run. This
+        setup is useful for testing different scenarios in Airflow tasks.
 
-        :Example:
+            :Example:
 
-            def sample_callable(**kwargs):
-                print("Hello World")
+                def sample_callable(**kwargs):
+                    print("Hello World")
 
-        task_instance = _create_test_dag_and_task(sample_callable, "sample_scenario")
-        # Use task_instance to simulate running a task in a test.
-    """
-    date = dt.datetime(2022, 1, 1)
-    dag = DAG(
-        f"test_{scenario_name}",
-        schedule=None,
-        start_date=date,
-    )
-    t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, python_callable=python_callable)
-    run_id = str(uuid.uuid1())
-    dagrun_kwargs: dict = {
-        "dag_version": None,
-        "logical_date": date,
-        "triggered_by": types.DagRunTriggeredByType.TEST,
-    }
-    dagrun = dag.create_dagrun(
-        run_id=run_id,
-        data_interval=(date, date),
-        run_type=types.DagRunType.MANUAL,
-        state=DagRunState.QUEUED,
-        **dagrun_kwargs,
-    )
-    task_instance = TaskInstance(t, run_id=run_id)
-    return dagrun, task_instance
+            task_instance = _create_test_dag_and_task(sample_callable, "sample_scenario")
+            # Use task_instance to simulate running a task in a test.
+        """
+        date = dt.datetime(2022, 1, 1)
+        dag = DAG(
+            f"test_{scenario_name}",
+            schedule=None,
+            start_date=date,
+        )
+        t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, python_callable=python_callable)
+        run_id = str(uuid.uuid1())
+        dagrun_kwargs: dict = {
+            "dag_version": None,
+            "logical_date": date,
+            "triggered_by": types.DagRunTriggeredByType.TEST,
+        }
+        dagrun = dag.create_dagrun(
+            run_id=run_id,
+            data_interval=(date, date),
+            run_type=types.DagRunType.MANUAL,
+            state=DagRunState.QUEUED,
+            **dagrun_kwargs,
+        )
+        task_instance = TaskInstance(t, run_id=run_id)
+        return dagrun, task_instance
 
     def _create_listener_and_task_instance(self) -> tuple[OpenLineageListener, TaskInstance]:
         """Creates and configures an OpenLineageListener instance and a mock TaskInstance for testing.
@@ -712,7 +712,7 @@ class TestOpenLineageListenerAirflow3:
         dag = DAG(
             "test",
             schedule=None,
-            start_date=dt.datetime(2022, 1, 1),
+            start_date=date,
             user_defined_macros={"render_df": render_df},
             params={"df": {"col": [1, 2]}},
         )
@@ -805,18 +805,26 @@ class TestOpenLineageListenerAirflow3:
             task_instance = _create_test_dag_and_task(sample_callable, "sample_scenario")
             # Use task_instance to simulate running a task in a test.
         """
+        date = dt.datetime(2022, 1, 1)
         dag = DAG(
             f"test_{scenario_name}",
             schedule=None,
-            start_date=dt.datetime(2022, 1, 1),
+            start_date=date,
         )
         t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, python_callable=python_callable)
         run_id = str(uuid.uuid1())
-        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST}
+        dagrun_kwargs = {
+            "dag_version": None,
+            "logical_date": date,
+            "triggered_by": types.DagRunTriggeredByType.TEST,
+        }
+
         dagrun = dag.create_dagrun(
-            state=State.NONE,  # type: ignore
             run_id=run_id,
-            **triggered_by_kwargs,  # type: ignore
+            data_interval=(date, date),
+            run_type=types.DagRunType.MANUAL,
+            state=DagRunState.QUEUED,
+            **dagrun_kwargs,
         )
         task_instance = TaskInstance(t, run_id=run_id)
         return dagrun, task_instance
@@ -898,6 +906,7 @@ class TestOpenLineageListenerAirflow3:
                     conf=None,
                 ),
                 task_reschedule_count=0,
+                max_tries=1,
             ),
         )
 
diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py
index 641f14817d8..54c3b8be5e8 100644
--- a/task_sdk/tests/execution_time/conftest.py
+++ b/task_sdk/tests/execution_time/conftest.py
@@ -146,7 +146,12 @@ def create_runtime_ti(mocked_parse, make_ti_context):
 
         startup_details = StartupDetails(
             ti=TaskInstance(
-                id=ti_id, task_id=task.task_id, dag_id=dag_id, run_id=run_id, try_number=try_number
+                id=ti_id,
+                task_id=task.task_id,
+                dag_id=dag_id,
+                run_id=run_id,
+                try_number=try_number,
+                start_date=start_date,
             ),
             dag_rel_path="",
             bundle_info=BundleInfo.model_construct(name="anything", version="any"),
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 40b4445e967..b67aefcb55a 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -708,7 +708,7 @@ class TestWatchedSubprocessKill:
         ti_id = uuid7()
 
         proc = ActivitySubprocess.start(
-            path=os.devnull,
+            dag_rel_path=os.devnull,
             bundle_info=FAKE_BUNDLE,
             what=TaskInstance(
                 id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, start_date=tz.utcnow()
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 680032c5889..4797072e5be 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -613,7 +613,9 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch
     task_id = "conditional_task"
 
     what = StartupDetails(
-        ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
+        ti=TaskInstance(
+            id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1, start_date=timezone.utcnow()
+        ),
         dag_rel_path="dag_parsing_context.py",
         bundle_info=BundleInfo(name="my-bundle", version=None),
         requests_fd=0,
@@ -653,7 +655,14 @@ class TestRuntimeTaskInstance:
         get_inline_dag(dag_id=dag_id, task=task)
 
         ti_id = uuid7()
-        ti = TaskInstance(id=ti_id, task_id=task.task_id, dag_id=dag_id, run_id="test_run", try_number=1)
+        ti = TaskInstance(
+            id=ti_id,
+            task_id=task.task_id,
+            dag_id=dag_id,
+            run_id="test_run",
+            try_number=1,
+            start_date=timezone.utcnow(),
+        )
 
         # Keep the context empty
         runtime_ti = RuntimeTaskInstance.model_construct(
