Copilot commented on code in PR #63889:
URL: https://github.com/apache/airflow/pull/63889#discussion_r3025344261


##########
airflow-ctl/src/airflowctl/ctl/cli_config.py:
##########
@@ -645,10 +655,21 @@ def _get_func(args: Namespace, api_operation: dict, api_client: Client = NEW_API
                                 datamodel_param_name = parameter_key
                             if expanded_parameter in self.excluded_parameters:
                                 continue
-                            if expanded_parameter in args_dict.keys():
+                            if (
+                                expanded_parameter in args_dict.keys()
+                                and args_dict[expanded_parameter] is not None
+                            ):
+                                val = args_dict[expanded_parameter]
+                                # Automatically convert comma-separated strings to lists if the field expects a list
+                                field_annotation = str(
+                                    datamodel.model_fields[expanded_parameter].annotation
+                                ).lower()
+                                if "list" in field_annotation and isinstance(val, str):
+                                    val = [v.strip() for v in val.split(",") if v.strip()]

Review Comment:
   `datamodel.model_fields[expanded_parameter]` can raise `KeyError` when a CLI 
arg exists in `args_dict` but the Pydantic model uses a different field 
name/alias (or the arg is for a non-model parameter). Also, detecting 
“list-ness” via `str(annotation).lower()` + substring `"list"` is brittle and 
can misclassify types. Prefer guarding with `if expanded_parameter in 
datamodel.model_fields`, and use `typing.get_origin()` / `typing.get_args()` 
(or Pydantic field metadata) to reliably detect `list[...]` and 
`Optional[list[...]]`.



##########
airflow-ctl/src/airflowctl/api/operations.py:
##########
@@ -893,3 +897,47 @@ def delete(
             return key
         except ServerResponseError as e:
             raise e
+
+
+class TaskInstanceOperations(BaseOperations):
+    """Task instance operations."""
+
+    def get(self, dag_id: str, dag_run_id: str, task_id: str) -> Any:
+        """Get a task instance."""
+        self.response = self.client.get(f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}")
+        data = self.response.json()
+        if isinstance(data, list):
+            return [TaskInstanceResponse.model_validate(item) for item in data]
+        if "task_instances" in data:
+            return TaskInstanceCollectionResponse.model_validate(data)
+        return TaskInstanceResponse.model_validate(data)

Review Comment:
   The method docstring says “Get a task instance”, but the implementation can 
return three different shapes (`list[TaskInstanceResponse]`, 
`TaskInstanceCollectionResponse`, or `TaskInstanceResponse`) while the 
signature is `-> Any`. This makes it harder for CLI/rendering code to behave 
predictably. Consider normalizing to a single return type for `get` (e.g., 
always return `TaskInstanceResponse`, unwrapping the collection/list when 
needed) or explicitly annotate a narrow union return type and ensure downstream 
formatting can handle it.



##########
airflow-ctl/tests/airflow_ctl/api/test_operations.py:
##########
@@ -1732,3 +1737,166 @@ def handle_request(request: httpx.Request) -> httpx.Response:
             map_index=self.map_index,
         )
         assert response == self.key
+
+
+class TestTaskInstanceOperations:
+    """Test suite for Task Instance operations."""
+
+    dag_id: str = "test_dag"
+    dag_run_id: str = "manual__2025-01-24T00:00:00+00:00"
+    task_id: str = "test_task"
+
+    task_instance_response = TaskInstanceResponse(
+        id=uuid.uuid4(),
+        task_id=task_id,
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        map_index=-1,
+        logical_date=datetime.datetime(2025, 1, 24, 0, 0, 0),
+        run_after=datetime.datetime(2025, 1, 24, 0, 0, 0),
+        start_date=datetime.datetime(2025, 1, 24, 0, 0, 1),
+        end_date=datetime.datetime(2025, 1, 24, 0, 0, 10),
+        duration=9.0,
+        state=TaskInstanceState.SUCCESS,
+        try_number=1,
+        max_tries=0,
+        task_display_name=task_id,
+        dag_display_name=dag_id,
+        hostname="hostname",
+        unixname="airflow",
+        pool="default_pool",
+        pool_slots=1,
+        queue="default",
+        priority_weight=1,
+        operator="EmptyOperator",
+        executor_config="{}",
+        note=None,
+    )
+
+    task_instance_collection_response = TaskInstanceCollectionResponse(
+        task_instances=[task_instance_response],
+        total_entries=1,
+    )
+
+    def test_get(self):
+        """Test fetching a single task instance."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == (
+                f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_id}"
+            )
+            return httpx.Response(200, json=json.loads(self.task_instance_response.model_dump_json()))

Review Comment:
   `TaskInstanceOperations.get()` and `.update()` include branches to handle 
responses shaped as a raw list and as a `{"task_instances": ...}` collection 
wrapper, but the unit tests only cover the “single object” response path. Add 
tests that return (1) a JSON list payload and (2) a 
`TaskInstanceCollectionResponse` payload for `get`/`update` to ensure those 
defensive branches behave as intended.



##########
airflow-ctl/tests/airflow_ctl/api/test_operations.py:
##########
@@ -1732,3 +1737,166 @@ def handle_request(request: httpx.Request) -> httpx.Response:
             map_index=self.map_index,
         )
         assert response == self.key
+
+
+class TestTaskInstanceOperations:
+    """Test suite for Task Instance operations."""
+
+    dag_id: str = "test_dag"
+    dag_run_id: str = "manual__2025-01-24T00:00:00+00:00"
+    task_id: str = "test_task"
+
+    task_instance_response = TaskInstanceResponse(
+        id=uuid.uuid4(),
+        task_id=task_id,
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        map_index=-1,
+        logical_date=datetime.datetime(2025, 1, 24, 0, 0, 0),
+        run_after=datetime.datetime(2025, 1, 24, 0, 0, 0),
+        start_date=datetime.datetime(2025, 1, 24, 0, 0, 1),
+        end_date=datetime.datetime(2025, 1, 24, 0, 0, 10),
+        duration=9.0,
+        state=TaskInstanceState.SUCCESS,
+        try_number=1,
+        max_tries=0,
+        task_display_name=task_id,
+        dag_display_name=dag_id,
+        hostname="hostname",
+        unixname="airflow",
+        pool="default_pool",
+        pool_slots=1,
+        queue="default",
+        priority_weight=1,
+        operator="EmptyOperator",
+        executor_config="{}",
+        note=None,
+    )
+
+    task_instance_collection_response = TaskInstanceCollectionResponse(
+        task_instances=[task_instance_response],
+        total_entries=1,
+    )
+
+    def test_get(self):
+        """Test fetching a single task instance."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == (
+                f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_id}"
+            )
+            return httpx.Response(200, json=json.loads(self.task_instance_response.model_dump_json()))
+
+        client = make_api_client(transport=httpx.MockTransport(handle_request))
+        response = client.task_instances.get(
+            dag_id=self.dag_id,
+            dag_run_id=self.dag_run_id,
+            task_id=self.task_id,
+        )
+        assert response == self.task_instance_response
+
+    def test_list(self):
+        """Test listing task instances for a DAG run."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == (f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances")
+            return httpx.Response(
+                200, json=json.loads(self.task_instance_collection_response.model_dump_json())
+            )
+
+        client = make_api_client(transport=httpx.MockTransport(handle_request))
+        response = client.task_instances.list(
+            dag_id=self.dag_id,
+            dag_run_id=self.dag_run_id,
+        )
+        assert response == self.task_instance_collection_response
+
+    def test_clear(self):
+        """Test clearing task instances with default options."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == f"/api/v2/dags/{self.dag_id}/clearTaskInstances"
+            request_body = json.loads(request.content)
+            assert request_body["dry_run"] is True
+            return httpx.Response(
+                200, json=json.loads(self.task_instance_collection_response.model_dump_json())
+            )
+
+        client = make_api_client(transport=httpx.MockTransport(handle_request))
+        body = ClearTaskInstancesBody(dry_run=True)
+        response = client.task_instances.clear(
+            dag_id=self.dag_id,
+            body=body,
+        )
+        assert response == self.task_instance_collection_response
+
+    def test_clear_with_options(self):
+        """Test clearing task instances with specific options."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == f"/api/v2/dags/{self.dag_id}/clearTaskInstances"
+            request_body = json.loads(request.content)
+            assert request_body["dry_run"] is False
+            assert request_body["only_failed"] is True
+            assert request_body["task_ids"] == [self.task_id]
+            assert request_body["dag_run_id"] == self.dag_run_id
+            return httpx.Response(
+                200, json=json.loads(self.task_instance_collection_response.model_dump_json())
+            )
+
+        client = make_api_client(transport=httpx.MockTransport(handle_request))
+        body = ClearTaskInstancesBody(
+            dry_run=False,
+            only_failed=True,
+            task_ids=[self.task_id],
+            dag_run_id=self.dag_run_id,
+        )
+        response = client.task_instances.clear(
+            dag_id=self.dag_id,
+            body=body,
+        )
+        assert response == self.task_instance_collection_response
+
+    def test_update(self):
+        """Test updating a task instance state."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == (
+                f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_id}"
+            )
+            request_body = json.loads(request.content)
+            assert request_body["new_state"] == TaskInstanceState.FAILED.value
+            return httpx.Response(200, json=json.loads(self.task_instance_response.model_dump_json()))

Review Comment:
   `TaskInstanceOperations.get()` and `.update()` include branches to handle 
responses shaped as a raw list and as a `{"task_instances": ...}` collection 
wrapper, but the unit tests only cover the “single object” response path. Add 
tests that return (1) a JSON list payload and (2) a 
`TaskInstanceCollectionResponse` payload for `get`/`update` to ensure those 
defensive branches behave as intended.



##########
airflow-ctl/tests/airflow_ctl/ctl/commands/test_pool_command.py:
##########
@@ -174,8 +174,11 @@ def test_export_json_to_file(self, mock_client, tmp_path, capsys):
 
         # Verify output message
         captured = capsys.readouterr()
-        expected_output = f"Exported {len(exported_data)} pool(s) to {export_file}"
-        assert expected_output in captured.out.replace("\n", "")
+        out_str = captured.out.replace("\n", "")
+        # The output contains rich ANSI codes, so we check for key substrings instead
+        assert "Exported" in out_str
+        assert str(len(exported_data)) in out_str
+        assert "pool" in out_str

Review Comment:
   These assertions became very loose (any output containing those substrings
will pass), which can hide regressions in the actual message/format. Since the
motivation is ANSI styling, a more robust approach is to strip ANSI escape
sequences first and then assert the full expected message,
`f"Exported {len(exported_data)} pool(s) to {export_file}"` (or at least a
stricter, structured subset of it).



##########
airflow-ctl/src/airflowctl/api/operations.py:
##########
@@ -893,3 +897,47 @@ def delete(
             return key
         except ServerResponseError as e:
             raise e
+
+
+class TaskInstanceOperations(BaseOperations):
+    """Task instance operations."""
+
+    def get(self, dag_id: str, dag_run_id: str, task_id: str) -> Any:
+        """Get a task instance."""
+        self.response = self.client.get(f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}")
+        data = self.response.json()
+        if isinstance(data, list):
+            return [TaskInstanceResponse.model_validate(item) for item in data]
+        if "task_instances" in data:
+            return TaskInstanceCollectionResponse.model_validate(data)
+        return TaskInstanceResponse.model_validate(data)
+
+    def list(self, dag_id: str, dag_run_id: str) -> TaskInstanceCollectionResponse | ServerResponseError:
+        """List task instances."""
+        return super().execute_list(
+            path=f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances",
+            data_model=TaskInstanceCollectionResponse,
+        )
+
+    def clear(
+        self, dag_id: str, body: ClearTaskInstancesBody
+    ) -> TaskInstanceCollectionResponse | ServerResponseError:
+        """Clear task instances."""
+        self.response = self.client.post(
+            f"dags/{dag_id}/clearTaskInstances",
+            json=body.model_dump(mode="json", exclude_unset=True),
+        )
+        return TaskInstanceCollectionResponse.model_validate_json(self.response.content)
+
+    def update(self, dag_id: str, dag_run_id: str, task_id: str, body: PatchTaskInstanceBody) -> Any:
+        """Update a task instance."""
+        self.response = self.client.patch(
+            f"dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}",
+            json=body.model_dump(mode="json", exclude_unset=True),
+        )
+        data = self.response.json()
+        if isinstance(data, list):
+            return [TaskInstanceResponse.model_validate(item) for item in data]
+        if "task_instances" in data:
+            return TaskInstanceCollectionResponse.model_validate(data)
+        return TaskInstanceResponse.model_validate(data)

Review Comment:
   The response-shape normalization logic is duplicated between `get()` and 
`update()`. Pulling this into a private helper (e.g., 
`_parse_task_instance_response(data)`) would reduce repetition and make it 
easier to keep behavior consistent if Airflow’s API responses change.



##########
airflow-ctl/tests/airflow_ctl/api/test_operations.py:
##########
@@ -1732,3 +1737,166 @@ def handle_request(request: httpx.Request) -> httpx.Response:
             map_index=self.map_index,
         )
         assert response == self.key
+
+
+class TestTaskInstanceOperations:
+    """Test suite for Task Instance operations."""
+
+    dag_id: str = "test_dag"
+    dag_run_id: str = "manual__2025-01-24T00:00:00+00:00"
+    task_id: str = "test_task"
+
+    task_instance_response = TaskInstanceResponse(
+        id=uuid.uuid4(),
+        task_id=task_id,
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        map_index=-1,
+        logical_date=datetime.datetime(2025, 1, 24, 0, 0, 0),
+        run_after=datetime.datetime(2025, 1, 24, 0, 0, 0),
+        start_date=datetime.datetime(2025, 1, 24, 0, 0, 1),
+        end_date=datetime.datetime(2025, 1, 24, 0, 0, 10),
+        duration=9.0,
+        state=TaskInstanceState.SUCCESS,
+        try_number=1,
+        max_tries=0,
+        task_display_name=task_id,
+        dag_display_name=dag_id,
+        hostname="hostname",
+        unixname="airflow",
+        pool="default_pool",
+        pool_slots=1,
+        queue="default",
+        priority_weight=1,
+        operator="EmptyOperator",
+        executor_config="{}",
+        note=None,
+    )
+
+    task_instance_collection_response = TaskInstanceCollectionResponse(
+        task_instances=[task_instance_response],
+        total_entries=1,
+    )
+
+    def test_get(self):
+        """Test fetching a single task instance."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == (
+                f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances/{self.task_id}"
+            )
+            return httpx.Response(200, json=json.loads(self.task_instance_response.model_dump_json()))
+
+        client = make_api_client(transport=httpx.MockTransport(handle_request))
+        response = client.task_instances.get(
+            dag_id=self.dag_id,
+            dag_run_id=self.dag_run_id,
+            task_id=self.task_id,
+        )
+        assert response == self.task_instance_response
+
+    def test_list(self):
+        """Test listing task instances for a DAG run."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == (f"/api/v2/dags/{self.dag_id}/dagRuns/{self.dag_run_id}/taskInstances")
+            return httpx.Response(
+                200, json=json.loads(self.task_instance_collection_response.model_dump_json())
+            )
+
+        client = make_api_client(transport=httpx.MockTransport(handle_request))
+        response = client.task_instances.list(
+            dag_id=self.dag_id,
+            dag_run_id=self.dag_run_id,
+        )
+        assert response == self.task_instance_collection_response
+
+    def test_clear(self):
+        """Test clearing task instances with default options."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == f"/api/v2/dags/{self.dag_id}/clearTaskInstances"
+            request_body = json.loads(request.content)
+            assert request_body["dry_run"] is True
+            return httpx.Response(
+                200, json=json.loads(self.task_instance_collection_response.model_dump_json())
+            )
+
+        client = make_api_client(transport=httpx.MockTransport(handle_request))
+        body = ClearTaskInstancesBody(dry_run=True)
+        response = client.task_instances.clear(
+            dag_id=self.dag_id,
+            body=body,
+        )
+        assert response == self.task_instance_collection_response
+
+    def test_clear_with_options(self):
+        """Test clearing task instances with specific options."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == f"/api/v2/dags/{self.dag_id}/clearTaskInstances"
+            request_body = json.loads(request.content)
+            assert request_body["dry_run"] is False
+            assert request_body["only_failed"] is True
+            assert request_body["task_ids"] == [self.task_id]
+            assert request_body["dag_run_id"] == self.dag_run_id
+            return httpx.Response(
+                200, json=json.loads(self.task_instance_collection_response.model_dump_json())
+            )
+
+        client = make_api_client(transport=httpx.MockTransport(handle_request))
+        body = ClearTaskInstancesBody(
+            dry_run=False,
+            only_failed=True,
+            task_ids=[self.task_id],
+            dag_run_id=self.dag_run_id,
+        )
+        response = client.task_instances.clear(
+            dag_id=self.dag_id,
+            body=body,
+        )
+        assert response == self.task_instance_collection_response
+
+    def test_update(self):
+        """Test updating a task instance state."""

Review Comment:
   `TaskInstanceOperations.get()` and `.update()` include branches to handle 
responses shaped as a raw list and as a `{"task_instances": ...}` collection 
wrapper, but the unit tests only cover the “single object” response path. Add 
tests that return (1) a JSON list payload and (2) a 
`TaskInstanceCollectionResponse` payload for `get`/`update` to ensure those 
defensive branches behave as intended.



##########
airflow-ctl/tests/airflow_ctl/api/test_operations.py:
##########
@@ -1732,3 +1737,166 @@ def handle_request(request: httpx.Request) -> httpx.Response:
             map_index=self.map_index,
         )
         assert response == self.key
+
+
+class TestTaskInstanceOperations:
+    """Test suite for Task Instance operations."""
+
+    dag_id: str = "test_dag"
+    dag_run_id: str = "manual__2025-01-24T00:00:00+00:00"
+    task_id: str = "test_task"
+
+    task_instance_response = TaskInstanceResponse(
+        id=uuid.uuid4(),
+        task_id=task_id,
+        dag_id=dag_id,
+        dag_run_id=dag_run_id,
+        map_index=-1,
+        logical_date=datetime.datetime(2025, 1, 24, 0, 0, 0),
+        run_after=datetime.datetime(2025, 1, 24, 0, 0, 0),
+        start_date=datetime.datetime(2025, 1, 24, 0, 0, 1),
+        end_date=datetime.datetime(2025, 1, 24, 0, 0, 10),
+        duration=9.0,
+        state=TaskInstanceState.SUCCESS,
+        try_number=1,
+        max_tries=0,
+        task_display_name=task_id,
+        dag_display_name=dag_id,
+        hostname="hostname",
+        unixname="airflow",
+        pool="default_pool",
+        pool_slots=1,
+        queue="default",
+        priority_weight=1,
+        operator="EmptyOperator",
+        executor_config="{}",
+        note=None,
+    )
+
+    task_instance_collection_response = TaskInstanceCollectionResponse(
+        task_instances=[task_instance_response],
+        total_entries=1,
+    )
+
+    def test_get(self):
+        """Test fetching a single task instance."""

Review Comment:
   `TaskInstanceOperations.get()` and `.update()` include branches to handle 
responses shaped as a raw list and as a `{"task_instances": ...}` collection 
wrapper, but the unit tests only cover the “single object” response path. Add 
tests that return (1) a JSON list payload and (2) a 
`TaskInstanceCollectionResponse` payload for `get`/`update` to ensure those 
defensive branches behave as intended.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to