amoghrajesh commented on code in PR #45043:
URL: https://github.com/apache/airflow/pull/45043#discussion_r1891222697


##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -689,7 +692,11 @@ def _handle_request(self, msg, log):
             self._task_end_time_monotonic = time.monotonic()
         elif isinstance(msg, GetConnection):
             conn = self.client.connections.get(msg.conn_id)
-            resp = conn.model_dump_json(exclude_unset=True).encode()
+            if isinstance(conn, ConnectionResponse):
+                conn_result = ConnectionResult.from_conn_response(conn)
+                resp = conn_result.model_dump_json(exclude_unset=True).encode()
+            elif isinstance(conn, ErrorResponse):
+                resp = conn.model_dump_json().encode()

Review Comment:
   It can only be one of the two, right? Wouldn't an `if / else` be better?



##########
task_sdk/src/airflow/sdk/api/client.py:
##########
@@ -161,9 +164,19 @@ class ConnectionOperations:
     def __init__(self, client: Client):
         self.client = client
 
-    def get(self, conn_id: str) -> ConnectionResponse:
+    def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse:
         """Get a connection from the API server."""
-        resp = self.client.get(f"connections/{conn_id}")
+        try:
+            resp = self.client.get(f"connections/{conn_id}")
+        except ServerResponseError as e:
+            if e.response.status_code == HTTPStatus.NOT_FOUND:
+                log.error(
+                    "Connection not found",
+                    conn_id=conn_id,
+                    detail=e.detail,
+                    status_code=e.response.status_code,
+                )
+                return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, 
detail={"conn_id": conn_id})

Review Comment:
   Yeah, I agree that the API client is a better place to handle the error than 
pushing it down to the supervisor. The supervisor's job is to delegate the 
API call to the client, which will handle it and return either a response or an error.



##########
task_sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -85,13 +86,26 @@ class XComResult(XComResponse):
 class ConnectionResult(ConnectionResponse):
     type: Literal["ConnectionResult"] = "ConnectionResult"
 
+    @classmethod
+    def from_conn_response(cls, connection_response: ConnectionResponse) -> 
ConnectionResult:
+        # Exclude defaults to avoid sending unnecessary data
+        # Pass the type as ConnectionResult explicitly so we can then call 
model_dump_json with exclude_unset=True
+        # to avoid sending unset fields (which are defaults in our case).
+        return cls(**connection_response.model_dump(exclude_defaults=True), 
type="ConnectionResult")
+
 
 class VariableResult(VariableResponse):
     type: Literal["VariableResult"] = "VariableResult"
 
 
+class ErrorResponse(BaseModel):
+    error: ErrorType = ErrorType.GENERIC_ERROR
+    detail: dict | None = None
+    type: Literal["ErrorResponse"] = "ErrorResponse"

Review Comment:
   I like this idea; having an error type makes it very easy to extend for 
future cases too.



##########
task_sdk/tests/execution_time/test_supervisor.py:
##########
@@ -764,7 +764,7 @@ def watched_subprocess(self, mocker):
         [
             pytest.param(
                 GetConnection(conn_id="test_conn"),
-                b'{"conn_id":"test_conn","conn_type":"mysql"}\n',
+                
b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n',

Review Comment:
   Maybe I am getting it wrong, but isn't the purpose of this test just to 
check whether a certain kind of "message" can be handled by the supervisor? 
It is essentially supervisor -> client communication. IIUC, we never touch any part of 
the task runner.



##########
task_sdk/tests/execution_time/test_context.py:
##########
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.sdk.definitions.connection import Connection
+from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse
+from airflow.sdk.execution_time.context import ConnectionAccessor, 
_convert_connection_result_conn
+
+
+def test_convert_connection_result_conn():
+    """Test that the ConnectionResult is converted to a Connection object."""
+    conn = ConnectionResult(
+        conn_id="test_conn",
+        conn_type="mysql",
+        host="mysql",
+        schema="airflow",
+        login="root",
+        password="password",
+        port=1234,
+        extra='{"extra_key": "extra_value"}',
+    )
+    conn = _convert_connection_result_conn(conn)
+    assert conn == Connection(
+        conn_id="test_conn",
+        conn_type="mysql",
+        host="mysql",
+        schema="airflow",
+        login="root",
+        password="password",
+        port=1234,
+        extra='{"extra_key": "extra_value"}',
+    )
+
+
+class TestConnectionAccessor:
+    def test_getattr_connection(self):
+        """
+        Test that the connection is fetched when accessed via __getattr__.
+
+        The __getattr__ method is used for template rendering. Example: ``{{ 
conn.mysql_conn.host }}``.
+        """
+        accessor = ConnectionAccessor()
+
+        # Conn from the supervisor / API Server
+        conn_result = ConnectionResult(conn_id="mysql_conn", 
conn_type="mysql", host="mysql", port=3306)
+
+        with mock.patch(
+            "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", 
create=True
+        ) as mock_supervisor_comms:
+            mock_supervisor_comms.get_message.return_value = conn_result
+
+            # Fetch the connection; Triggers __getattr__
+            conn = accessor.mysql_conn

Review Comment:
   ```suggestion
               # Fetch the connection; triggers __getattr__
               conn = accessor.mysql_conn
   ```



##########
task_sdk/src/airflow/sdk/execution_time/context.py:
##########
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import structlog
+
+if TYPE_CHECKING:
+    from airflow.sdk.execution_time.comms import ConnectionResult
+
+
+def _convert_connection_result_conn(conn_result: ConnectionResult):
+    from airflow.sdk.definitions.connection import Connection
+
+    return Connection(**conn_result.model_dump(exclude={"type"}, 
by_alias=True))

Review Comment:
   Thanks!



##########
task_sdk/src/airflow/sdk/execution_time/context.py:
##########
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import structlog
+
+from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
+
+if TYPE_CHECKING:
+    from airflow.sdk.definitions.connection import Connection
+    from airflow.sdk.execution_time.comms import ConnectionResult
+
+
+def _convert_connection_result_conn(conn_result: ConnectionResult):
+    from airflow.sdk.definitions.connection import Connection
+
+    # `by_alias=True` is used to convert the `schema` field to `schema_` in 
the Connection model
+    return Connection(**conn_result.model_dump(exclude={"type"}, 
by_alias=True))
+
+
+def _get_connection(conn_id: str) -> Connection:
+    # TODO: This should probably be moved to a separate module like 
`airflow.sdk.execution_time.comms`
+    #   or `airflow.sdk.execution_time.connection`
+    #   A reason to not move it to `airflow.sdk.execution_time.comms` is that 
it
+    #   will make that module depend on Task SDK, which is not ideal because 
we intend to
+    #   keep Task SDK as a separate package than execution time mods.
+    from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
+    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+    log = structlog.get_logger(logger_name="task")
+    SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id))
+    msg = SUPERVISOR_COMMS.get_message()
+    if isinstance(msg, ErrorResponse):
+        raise AirflowRuntimeError(msg)
+
+    if TYPE_CHECKING:
+        assert isinstance(msg, ConnectionResult)
+    return _convert_connection_result_conn(msg)
+
+
+class ConnectionAccessor:
+    """Wrapper to access Connection entries in template."""
+
+    def __getattr__(self, conn_id: str) -> Any:
+        return _get_connection(conn_id)
+
+    def __repr__(self) -> str:
+        return "<ConnectionAccessor (dynamic access)>"

Review Comment:
   Do we need to expand this one further, or should we leave it as is?



##########
task_sdk/tests/execution_time/test_context.py:
##########
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.sdk.definitions.connection import Connection
+from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse
+from airflow.sdk.execution_time.context import ConnectionAccessor, 
_convert_connection_result_conn
+
+
+def test_convert_connection_result_conn():
+    """Test that the ConnectionResult is converted to a Connection object."""
+    conn = ConnectionResult(
+        conn_id="test_conn",
+        conn_type="mysql",
+        host="mysql",
+        schema="airflow",
+        login="root",
+        password="password",
+        port=1234,
+        extra='{"extra_key": "extra_value"}',
+    )
+    conn = _convert_connection_result_conn(conn)
+    assert conn == Connection(
+        conn_id="test_conn",
+        conn_type="mysql",
+        host="mysql",
+        schema="airflow",
+        login="root",
+        password="password",
+        port=1234,
+        extra='{"extra_key": "extra_value"}',
+    )
+
+
+class TestConnectionAccessor:
+    def test_getattr_connection(self):
+        """
+        Test that the connection is fetched when accessed via __getattr__.
+
+        The __getattr__ method is used for template rendering. Example: ``{{ 
conn.mysql_conn.host }}``.
+        """
+        accessor = ConnectionAccessor()
+
+        # Conn from the supervisor / API Server
+        conn_result = ConnectionResult(conn_id="mysql_conn", 
conn_type="mysql", host="mysql", port=3306)
+
+        with mock.patch(
+            "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", 
create=True
+        ) as mock_supervisor_comms:
+            mock_supervisor_comms.get_message.return_value = conn_result
+
+            # Fetch the connection; Triggers __getattr__
+            conn = accessor.mysql_conn

Review Comment:
   Should we also assert that `__getattr__` was called?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to