anishgirianish commented on code in PR #62343:
URL: https://github.com/apache/airflow/pull/62343#discussion_r3177849992
##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py:
##########
@@ -1222,6 +1231,184 @@ def
test_should_test_new_connection_without_existing(self, test_client):
assert response.json()["status"] is True
+class TestAsyncConnectionTest(TestConnectionEndpoint):
+ """Tests for the async connection test endpoints (POST + GET polling)."""
+
+ TEST_REQUEST_BODY = {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "host": TEST_CONN_HOST,
+ }
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_should_respond_202(self, test_client, session):
+ """POST /connections/test-async returns 202 + token."""
+ response = test_client.post("/connections/test-async",
json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 202
+ body = response.json()
+ assert "token" in body
+ assert body["connection_id"] == TEST_CONN_ID
+ assert body["state"] == "pending"
+ assert len(body["token"]) > 0
+
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.post("/connections/test-async",
json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.post("/connections/test-async",
json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 403
+
+ def test_should_respond_403_by_default(self, test_client):
+ """Connection testing is disabled by default."""
+ response = test_client.post("/connections/test-async",
json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 403
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_creates_connection_test_request_row(self, test_client,
session):
+ """POST creates a ConnectionTestRequest row in PENDING state with
connection fields."""
+ response = test_client.post("/connections/test-async",
json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 202
+ token = response.json()["token"]
+
+ ct =
session.scalar(select(ConnectionTestRequest).filter_by(token=token))
+ assert ct is not None
+ assert ct.connection_id == TEST_CONN_ID
+ assert ct.conn_type == TEST_CONN_TYPE
+ assert ct.host == TEST_CONN_HOST
+ assert ct.state == "pending"
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_passes_queue_parameter(self, test_client, session):
+ """POST /connections/test-async passes the queue parameter."""
+ body = {**self.TEST_REQUEST_BODY, "queue": "gpu_workers"}
+ response = test_client.post("/connections/test-async", json=body)
+ assert response.status_code == 202
+ token = response.json()["token"]
+
+ ct =
session.scalar(select(ConnectionTestRequest).filter_by(token=token))
+ assert ct is not None
+ assert ct.queue == "gpu_workers"
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_stores_commit_on_success(self, test_client, session):
+ """POST /connections/test-async stores the commit_on_success flag."""
+ body = {**self.TEST_REQUEST_BODY, "commit_on_success": True}
+ response = test_client.post("/connections/test-async", json=body)
+ assert response.status_code == 202
+ token = response.json()["token"]
+
+ ct =
session.scalar(select(ConnectionTestRequest).filter_by(token=token))
+ assert ct is not None
+ assert ct.commit_on_success is True
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_returns_409_for_duplicate_active_test(self, test_client,
session):
+ """POST returns 409 when there's already an active test for the same
connection_id."""
+ response = test_client.post("/connections/test-async",
json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 202
+
+ response = test_client.post("/connections/test-async",
json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 409
+ assert response.json()["detail"]["reason"] == "Unique constraint
violation"
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_rejects_unknown_executor_with_422(self, test_client,
session):
+ """POST returns 422 when the requested executor is not configured."""
+ body = {**self.TEST_REQUEST_BODY, "executor": "no_such_executor"}
+ response = test_client.post("/connections/test-async", json=body)
+ assert response.status_code == 422
+ assert "no_such_executor" in response.text
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_accepts_configured_executor(self, test_client, session):
+ """POST accepts an executor name that matches a configured executor."""
+ configured = ExecutorLoader.get_executor_names(validate_teams=False)
+ executor_name = configured[0].alias or configured[0].module_path
+ body = {**self.TEST_REQUEST_BODY, "executor": executor_name}
+ response = test_client.post("/connections/test-async", json=body)
+ assert response.status_code == 202
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_get_status_returns_pending(self, test_client, session):
+ """GET /connections/test-async/{token} returns current status."""
+ post_response = test_client.post("/connections/test-async",
json=self.TEST_REQUEST_BODY)
+ token = post_response.json()["token"]
+
+ response = test_client.get(f"/connections/test-async/{token}")
+ assert response.status_code == 200
+ body = response.json()
+ assert body["token"] == token
+ assert body["connection_id"] == TEST_CONN_ID
+ assert body["state"] == "pending"
+ assert body["result_message"] is None
+ assert "created_at" in body
+ assert "reverted" not in body
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_get_status_returns_completed_result(self, test_client, session):
+ """GET returns result after the worker has updated the test."""
+ post_response = test_client.post("/connections/test-async",
json=self.TEST_REQUEST_BODY)
+ token = post_response.json()["token"]
+
+ ct =
session.scalar(select(ConnectionTestRequest).filter_by(token=token))
+ ct.state = ConnectionTestState.SUCCESS
+ ct.result_message = "Connection successfully tested"
+ session.commit()
+
+ response = test_client.get(f"/connections/test-async/{token}")
+ assert response.status_code == 200
+ body = response.json()
+ assert body["state"] == "success"
+ assert body["result_message"] == "Connection successfully tested"
+
+ def test_get_status_returns_404_for_invalid_token(self, test_client):
Review Comment:
Added test_get_status_unauthorized_user_does_not_leak_row, thanks.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]