This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a6ff00f43cd Fix sql_warehouse_name resolution: handle 'warehouses' API response key (#63286)
a6ff00f43cd is described below
commit a6ff00f43cded82e1a832cbb141e4ad946b4e519
Author: ataulmujeeb-cyber <[email protected]>
AuthorDate: Wed Mar 11 09:03:19 2026 -0400
Fix sql_warehouse_name resolution: handle 'warehouses' API response key (#63286)
* Fix sql_warehouse_name resolution failing with "Can't list Databricks SQL
endpoints"
The _get_sql_endpoint_by_name method calls GET /api/2.0/sql/warehouses
(the current API path) but checks for the "endpoints" key in the
response. Since Databricks renamed SQL endpoints to SQL warehouses,
the current API returns data under the "warehouses" key, causing the
check to always fail.
This fix handles both the current ("warehouses") and legacy
("endpoints") response keys for backward compatibility.
Closes: #63285
* Use standard Python exceptions instead of AirflowException
Replace AirflowException with standard Python exceptions per
contributing guidelines:
- RuntimeError for unexpected API response (no warehouses/endpoints key)
- ValueError for warehouse name not found in results
---
.../providers/databricks/hooks/databricks_sql.py | 16 +++--
.../unit/databricks/hooks/test_databricks_sql.py | 84 +++++++++++++++++++++-
2 files changed, 94 insertions(+), 6 deletions(-)
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index 2c2164bd9c7..021142395b2 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -129,12 +129,20 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
result = self._do_api_call(LIST_SQL_ENDPOINTS_ENDPOINT)
- if "endpoints" not in result:
- raise AirflowException("Can't list Databricks SQL endpoints")
+ # The API response key depends on which endpoint path is used:
+ # - "warehouses" for the current /api/2.0/sql/warehouses path
+ # - "endpoints" for the legacy /api/2.0/sql/endpoints path
+ warehouses = result.get("warehouses") or result.get("endpoints")
+ if not warehouses:
+            raise RuntimeError(
+                "Can't list Databricks SQL warehouses. The API response contained neither "
+                "'warehouses' nor 'endpoints' key. Check that the connection has sufficient "
+                "permissions to list SQL warehouses."
+            )
try:
-            endpoint = next(endpoint for endpoint in result["endpoints"] if endpoint["name"] == endpoint_name)
+            endpoint = next(ep for ep in warehouses if ep["name"] == endpoint_name)
         except StopIteration:
-            raise AirflowException(f"Can't find Databricks SQL endpoint with name '{endpoint_name}'")
+            raise ValueError(f"Can't find Databricks SQL warehouse with name '{endpoint_name}'")
else:
return endpoint
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
index d661f5b0714..f3f053d443c 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
@@ -80,10 +80,10 @@ def mock_get_requests():
     mock_patch = patch("airflow.providers.databricks.hooks.databricks_base.requests")
mock_requests = mock_patch.start()
- # Configure the mock object
+    # Configure the mock object with the current API response format ("warehouses" key)
mock_requests.codes.ok = 200
mock_requests.get.return_value.json.return_value = {
- "endpoints": [
+ "warehouses": [
{
"id": "1264e5078741679a",
"name": "Test",
@@ -712,3 +712,83 @@ def test_get_df(df_type, df_class, description):
assert df.row(1)[0] == result_sets[1][0]
assert isinstance(df, df_class)
+
+
+class TestGetSqlEndpointByName:
+    """Tests for _get_sql_endpoint_by_name with both 'warehouses' and legacy 'endpoints' API response keys."""
+
+ @patch("airflow.providers.databricks.hooks.databricks_base.requests")
+ def test_resolve_warehouse_name_with_warehouses_key(self, mock_requests):
+        """Test that the current API response format with 'warehouses' key works."""
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = {
+ "warehouses": [
+ {
+ "id": "abc123",
+ "name": "My Warehouse",
+ "odbc_params": {
+ "hostname": "xx.cloud.databricks.com",
+ "path": "/sql/1.0/warehouses/abc123",
+ },
+ }
+ ]
+ }
+        type(mock_requests.get.return_value).status_code = PropertyMock(return_value=200)
+
+ hook = DatabricksSqlHook(sql_endpoint_name="My Warehouse")
+ endpoint = hook._get_sql_endpoint_by_name("My Warehouse")
+ assert endpoint["id"] == "abc123"
+ assert endpoint["odbc_params"]["path"] == "/sql/1.0/warehouses/abc123"
+
+ @patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_resolve_warehouse_name_with_legacy_endpoints_key(self, mock_requests):
+        """Test that the legacy API response format with 'endpoints' key still works."""
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = {
+ "endpoints": [
+ {
+ "id": "def456",
+ "name": "Legacy Endpoint",
+ "odbc_params": {
+ "hostname": "xx.cloud.databricks.com",
+ "path": "/sql/1.0/endpoints/def456",
+ },
+ }
+ ]
+ }
+        type(mock_requests.get.return_value).status_code = PropertyMock(return_value=200)
+
+ hook = DatabricksSqlHook(sql_endpoint_name="Legacy Endpoint")
+ endpoint = hook._get_sql_endpoint_by_name("Legacy Endpoint")
+ assert endpoint["id"] == "def456"
+ assert endpoint["odbc_params"]["path"] == "/sql/1.0/endpoints/def456"
+
+ @patch("airflow.providers.databricks.hooks.databricks_base.requests")
+ def test_resolve_warehouse_name_not_found(self, mock_requests):
+        """Test that a clear error is raised when the warehouse name doesn't match any warehouse."""
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = {
+ "warehouses": [
+ {
+ "id": "abc123",
+ "name": "Some Other Warehouse",
+ "odbc_params": {"path": "/sql/1.0/warehouses/abc123"},
+ }
+ ]
+ }
+        type(mock_requests.get.return_value).status_code = PropertyMock(return_value=200)
+
+ hook = DatabricksSqlHook(sql_endpoint_name="Nonexistent Warehouse")
+        with pytest.raises(ValueError, match="Can't find Databricks SQL warehouse with name"):
+ hook._get_sql_endpoint_by_name("Nonexistent Warehouse")
+
+ @patch("airflow.providers.databricks.hooks.databricks_base.requests")
+ def test_resolve_warehouse_name_empty_response(self, mock_requests):
+        """Test that a clear error is raised when the API returns no warehouses."""
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = {}
+        type(mock_requests.get.return_value).status_code = PropertyMock(return_value=200)
+
+ hook = DatabricksSqlHook(sql_endpoint_name="Test")
+        with pytest.raises(RuntimeError, match="Can't list Databricks SQL warehouses"):
+ hook._get_sql_endpoint_by_name("Test")