kaxil commented on code in PR #44229: URL: https://github.com/apache/airflow/pull/44229#discussion_r1852434168
########## task_sdk/tests/execution_time/test_supervisor.py: ########## @@ -225,3 +225,88 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): "logger": "task", "timestamp": "2024-11-07T12:34:56.078901Z", } in captured_logs + + +class TestHandleRequest: + @pytest.fixture + def watched_subprocess(self, mocker): + """Fixture to provide a WatchedSubprocess instance.""" + return WatchedSubprocess( + ti_id=uuid7(), + pid=12345, + stdin=mocker.Mock(), # Not used in these tests + stdout=mocker.Mock(), # Not used in these tests + stderr=mocker.Mock(), # Not used in these tests + client=mocker.Mock(), + process=mocker.Mock(), + ) + + @pytest.mark.parametrize( + ["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response"], + [ + pytest.param( + GetConnection(conn_id="test_conn"), + b'{"conn_id":"test_conn","conn_type":"mysql"}\n', + "connections.get", + "test_conn", + ConnectionResult(conn_id="test_conn", conn_type="mysql"), + id="get_connection", + ), + pytest.param( + GetVariable(key="test_key"), + b'{"key":"test_key","value":"test_value"}\n', + "variables.get", + "test_key", + VariableResult(key="test_key", value="test_value"), + id="get_variable", + ), + ], + ) + def test_handle_requests( + self, + watched_subprocess, + mocker, + message, + expected_buffer, + client_attr_path, + method_arg, + mock_response, + ): + """ + Test handling of different messages to the subprocess. For any new message type, add a + new parameter set to the `@pytest.mark.parametrize` decorator. + + For each message type, this test: + + 1. Sends the message to the subprocess. + 2. Verifies that the correct client method is called with the expected argument. + 3. Checks that the buffer is updated with the expected response. 
+ """ + + def _resolve_nested_attr(obj, attr_path): + """Helper to resolve nested attributes like 'variables.get'.""" + attrs = attr_path.split(".") + for attr in attrs: + obj = getattr(obj, attr) + return obj + + # Mock the client method. E.g. `client.variables.get` or `client.connections.get` + mock_client_method = _resolve_nested_attr(watched_subprocess.client, client_attr_path) + mock_client_method.return_value = mock_response + + # Mock buffer directly as a real bytearray to avoid TypeError + buffer = bytearray() + mocker.patch("airflow.sdk.execution_time.supervisor.bytearray", return_value=buffer) Review Comment: Indeed, it was hiding a bug. ``` FAILED task_sdk/tests/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_variable] - assert equals failed b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 b'{"key":"test_key","value":"test_value"}\n' \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00{"key":"test_key","value":"test_value"}\n' ``` That was because msgspec truncated the buffer to the end of the serialized message for us via [`encoder.encode_into`](https://jcristharif.com/msgspec/api.html#msgspec.json.Encoder.encode_into). So we could add `buffer[:] = b""` to clear the buffer while preserving the pre-allocated memory. But @ashb and I discussed it and agreed to simplify and remove the buffer handling, since Pydantic doesn't support buffers. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org