This is an automated email from the ASF dual-hosted git repository.

yzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/polaris-tools.git


The following commit(s) were added to refs/heads/main by this push:
     new 19c56b7  [MCP] feat: Add realm support (#83)
19c56b7 is described below

commit 19c56b70903c04df995f6f8b0441997c13380d95
Author: Yong Zheng <[email protected]>
AuthorDate: Wed Dec 3 19:07:02 2025 -0600

    [MCP] feat: Add realm support (#83)
---
 mcp-server/README.md                           |   7 +-
 mcp-server/polaris_mcp/authorization.py        | 116 +++++++++-----
 mcp-server/polaris_mcp/rest.py                 |   9 +-
 mcp-server/polaris_mcp/server.py               |  39 +++--
 mcp-server/polaris_mcp/tools/catalog.py        |   4 +
 mcp-server/polaris_mcp/tools/catalog_role.py   |   4 +
 mcp-server/polaris_mcp/tools/namespace.py      |   4 +
 mcp-server/polaris_mcp/tools/policy.py         |   4 +
 mcp-server/polaris_mcp/tools/principal.py      |   4 +
 mcp-server/polaris_mcp/tools/principal_role.py |   4 +
 mcp-server/polaris_mcp/tools/table.py          |   4 +
 mcp-server/tests/test_authorization.py         | 210 ++++++++++++++++++++++---
 mcp-server/tests/test_rest_tool.py             |  65 ++++++++
 mcp-server/tests/test_server.py                |   5 +-
 14 files changed, 400 insertions(+), 79 deletions(-)

diff --git a/mcp-server/README.md b/mcp-server/README.md
index 98f7998..a6c5fd4 100644
--- a/mcp-server/README.md
+++ b/mcp-server/README.md
@@ -157,6 +157,11 @@ uv run client.py http://localhost:8000/mcp \
 | `POLARIS_CLIENT_SECRET`                                        | OAuth 
client secret.                                             | _unset_            
                              |
 | `POLARIS_TOKEN_SCOPE`                                          | OAuth scope 
string.                                              | _unset_                  
                        |
 | `POLARIS_TOKEN_URL`                                            | Optional 
override for the token endpoint URL.                    | 
`${POLARIS_BASE_URL}api/catalog/v1/oauth/tokens` |
+| `POLARIS_REALM_{realm}_CLIENT_ID`                              | OAuth 
client id for a specific realm.                            | _unset_            
                              |
+| `POLARIS_REALM_{realm}_CLIENT_SECRET`                          | OAuth 
client secret for a specific realm.                        | _unset_            
                              |
+| `POLARIS_REALM_{realm}_TOKEN_SCOPE`                            | OAuth scope 
for a specific realm.                                | _unset_                  
                        |
+| `POLARIS_REALM_{realm}_TOKEN_URL`                              | Token 
endpoint URL for a specific realm.                         | _unset_            
                              |
+| `POLARIS_REALM_CONTEXT_HEADER_NAME`                            | Header name 
used for realm context.                              | `Polaris-Realm`          
                        |
 | `POLARIS_TOKEN_REFRESH_BUFFER_SECONDS`                         | Minimum 
remaining token lifetime before refreshing in seconds.   | `60.0`               
                            |
 | `POLARIS_HTTP_TIMEOUT_SECONDS`                                 | Default 
timeout in seconds for all HTTP requests.                | `30.0`               
                            |
 | `POLARIS_HTTP_CONNECT_TIMEOUT_SECONDS`                         | Timeout in 
seconds for establishing HTTP connections.            | `30.0`                  
                         |
@@ -166,8 +171,8 @@ uv run client.py http://localhost:8000/mcp \
 | `POLARIS_CONFIG_FILE`                                          | Path to a 
configuration file containing configuration variables. | `.polaris_mcp.env` in 
current working directory  |
 
 
-
 When OAuth variables are supplied, the server automatically acquires and 
refreshes tokens using the client credentials flow; otherwise a static bearer 
token is used if provided.
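+Realm-specific variables (e.g., `POLARIS_REALM_{realm}_CLIENT_ID`) set the client ID, client secret, token scope, and token URL used when a tool call specifies that realm. For such calls only the realm-specific variables are read; the global `POLARIS_CLIENT_ID`, `POLARIS_CLIENT_SECRET`, `POLARIS_TOKEN_SCOPE`, and `POLARIS_TOKEN_URL` are never used as a fallback. If the realm's client ID or client secret is missing, no OAuth token is requested for that realm, and if its token URL is unset, `${POLARIS_BASE_URL}api/catalog/v1/oauth/tokens` is used.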
 
 ## Tools
 
diff --git a/mcp-server/polaris_mcp/authorization.py 
b/mcp-server/polaris_mcp/authorization.py
index 2b61894..6f38076 100644
--- a/mcp-server/polaris_mcp/authorization.py
+++ b/mcp-server/polaris_mcp/authorization.py
@@ -22,11 +22,12 @@
 from __future__ import annotations
 
 import json
+import os
 import threading
 import time
 from abc import ABC, abstractmethod
 from typing import Optional
-from urllib.parse import urlencode
+from urllib.parse import urlencode, urljoin
 
 import urllib3
 
@@ -35,7 +36,7 @@ class AuthorizationProvider(ABC):
     """Return Authorization header values for outgoing requests."""
 
     @abstractmethod
-    def authorization_header(self) -> Optional[str]: ...
+    def authorization_header(self, realm: Optional[str] = None) -> 
Optional[str]: ...
 
 
 class StaticAuthorizationProvider(AuthorizationProvider):
@@ -45,7 +46,7 @@ class StaticAuthorizationProvider(AuthorizationProvider):
         value = (token or "").strip()
         self._header = f"Bearer {value}" if value else None
 
-    def authorization_header(self) -> Optional[str]:
+    def authorization_header(self, realm: Optional[str] = None) -> 
Optional[str]:
         return self._header
 
 
@@ -54,59 +55,100 @@ class 
ClientCredentialsAuthorizationProvider(AuthorizationProvider):
 
     def __init__(
         self,
-        token_endpoint: str,
-        client_id: str,
-        client_secret: str,
-        scope: Optional[str],
+        base_url: str,
         http: urllib3.PoolManager,
         refresh_buffer_seconds: float,
         timeout: urllib3.Timeout,
     ) -> None:
-        self._token_endpoint = token_endpoint
-        self._client_id = client_id
-        self._client_secret = client_secret
-        self._scope = scope
+        self._base_url = base_url
         self._http = http
+        self._refresh_buffer_seconds = max(refresh_buffer_seconds, 0.0)
         self._timeout = timeout
         self._lock = threading.Lock()
-        self._cached: Optional[tuple[str, float]] = None  # (token, 
expires_at_epoch)
-        self._refresh_buffer_seconds = max(refresh_buffer_seconds, 0.0)
+        # {realm: (token, expires_at_epoch)}
+        self._cached: dict[str, tuple[str, float]] = {}
 
-    def authorization_header(self) -> Optional[str]:
-        token = self._current_token()
+    def authorization_header(self, realm: Optional[str] = None) -> 
Optional[str]:
+        token = self._get_token_from_realm(realm)
         return f"Bearer {token}" if token else None
 
-    def _current_token(self) -> Optional[str]:
-        now = time.time()
-        cached = self._cached
-        if not cached or cached[1] - self._refresh_buffer_seconds <= now:
-            with self._lock:
-                cached = self._cached
-                if (
-                    not cached
-                    or cached[1] - self._refresh_buffer_seconds <= time.time()
-                ):
-                    self._cached = cached = self._fetch_token()
-        return cached[0] if cached else None
-
-    def _fetch_token(self) -> tuple[str, float]:
+    def _get_token_from_realm(self, realm: Optional[str]) -> Optional[str]:
+        def needs_refresh(cached):
+            return (
+                cached is None
+                or cached[1] - self._refresh_buffer_seconds <= time.time()
+            )
+
+        cache_key = realm or ""
+        token = self._cached.get(cache_key)
+        # Token not expired
+        if not needs_refresh(token):
+            return token[0]
+        # Acquire lock and verify again if token expired
+        with self._lock:
+            token = self._cached.get(cache_key)
+            if needs_refresh(token):
+                credentials = self._get_credentials_from_realm(realm)
+                if not credentials:
+                    return None
+                token = self._fetch_token(realm, credentials)
+                self._cached[cache_key] = token
+        return token[0] if token else None
+
+    def _get_credentials_from_realm(
+        self, realm: Optional[str]
+    ) -> Optional[dict[str, str]]:
+        def get_env(key: str) -> Optional[str]:
+            val = os.getenv(key)
+            return val.strip() or None if val else None
+
+        def load_creds(realm: Optional[str] = None) -> dict[str, 
Optional[str]]:
+            prefix = f"POLARIS_REALM_{realm}_" if realm else "POLARIS_"
+            return {
+                "client_id": get_env(f"{prefix}CLIENT_ID"),
+                "client_secret": get_env(f"{prefix}CLIENT_SECRET"),
+                "scope": get_env(f"{prefix}TOKEN_SCOPE"),
+                "token_url": get_env(f"{prefix}TOKEN_URL"),
+            }
+
+        # Only use realm-specific credentials
+        if realm:
+            creds = load_creds(realm)
+            if creds["client_id"] and creds["client_secret"]:
+                return creds
+            return None
+        # No realm specified, use global credentials
+        creds = load_creds()
+        if creds["client_id"] and creds["client_secret"]:
+            return creds
+        return None
+
+    def _fetch_token(
+        self, realm: Optional[str], credentials: dict[str, str]
+    ) -> tuple[str, float]:
+        token_url = credentials.get("token_url") or urljoin(
+            self._base_url, "api/catalog/v1/oauth/tokens"
+        )
         payload = {
             "grant_type": "client_credentials",
-            "client_id": self._client_id,
-            "client_secret": self._client_secret,
+            "client_id": credentials["client_id"],
+            "client_secret": credentials["client_secret"],
         }
-        if self._scope:
-            payload["scope"] = self._scope
+        if credentials.get("scope"):
+            payload["scope"] = credentials["scope"]
 
         encoded = urlencode(payload)
+        header_name = os.getenv("POLARIS_REALM_CONTEXT_HEADER_NAME", 
"Polaris-Realm")
+        headers = {"Content-Type": "application/x-www-form-urlencoded"}
+        if realm:
+            headers[header_name] = realm
         response = self._http.request(
             "POST",
-            self._token_endpoint,
+            token_url,
             body=encoded,
-            headers={"Content-Type": "application/x-www-form-urlencoded"},
+            headers=headers,
             timeout=self._timeout,
         )
-
         if response.status != 200:
             raise RuntimeError(
                 f"OAuth token endpoint returned {response.status}: 
{response.data.decode('utf-8', errors='ignore')}"
@@ -132,7 +174,7 @@ class 
ClientCredentialsAuthorizationProvider(AuthorizationProvider):
 
 
 class _NoneAuthorizationProvider(AuthorizationProvider):
-    def authorization_header(self) -> Optional[str]:
+    def authorization_header(self, realm: Optional[str] = None) -> 
Optional[str]:
         return None
 
 
diff --git a/mcp-server/polaris_mcp/rest.py b/mcp-server/polaris_mcp/rest.py
index 8c9728d..6e02633 100644
--- a/mcp-server/polaris_mcp/rest.py
+++ b/mcp-server/polaris_mcp/rest.py
@@ -22,6 +22,7 @@
 from __future__ import annotations
 
 import json
+import os
 from typing import Any, Dict, List, Optional, Tuple
 from urllib.parse import urlencode, urljoin, urlsplit, urlunsplit, quote
 
@@ -230,6 +231,7 @@ class PolarisRestTool:
         query_params = arguments.get("query")
         headers_param = arguments.get("headers")
         body_node = arguments.get("body")
+        realm = arguments.get("realm")
 
         query = query_params if isinstance(query_params, dict) else None
         headers = headers_param if isinstance(headers_param, dict) else None
@@ -238,9 +240,14 @@ class PolarisRestTool:
 
         header_values = _merge_headers(headers)
         if not any(name.lower() == "authorization" for name in header_values):
-            token = self._authorization.authorization_header()
+            token = self._authorization.authorization_header(realm)
             if token:
                 header_values["Authorization"] = token
+        header_name = os.getenv("POLARIS_REALM_CONTEXT_HEADER_NAME", 
"Polaris-Realm")
+        if realm and not any(
+            name.lower() == header_name.lower() for name in header_values
+        ):
+            header_values[header_name] = realm
 
         body_text = _serialize_body(body_node)
         if body_text is not None and not any(
diff --git a/mcp-server/polaris_mcp/server.py b/mcp-server/polaris_mcp/server.py
index a338612..7ebb812 100644
--- a/mcp-server/polaris_mcp/server.py
+++ b/mcp-server/polaris_mcp/server.py
@@ -27,7 +27,7 @@ import logging.config
 import argparse
 import os
 from typing import Any, Mapping, MutableMapping, Sequence, Optional
-from urllib.parse import urljoin, urlparse
+from urllib.parse import urlparse
 
 import urllib3
 from fastmcp import FastMCP
@@ -164,6 +164,7 @@ def create_server() -> FastMCP:
         query: Mapping[str, str | Sequence[str]] | None = None,
         headers: Mapping[str, str | Sequence[str]] | None = None,
         body: Any | None = None,
+        realm: str | None = None,
     ) -> FastMcpToolResult:
         return _call_tool(
             table_tool,
@@ -177,6 +178,7 @@ def create_server() -> FastMCP:
                 "query": query,
                 "headers": headers,
                 "body": body,
+                "realm": realm,
             },
             transforms={
                 "namespace": _normalize_namespace,
@@ -198,6 +200,7 @@ def create_server() -> FastMCP:
         query: Mapping[str, str | Sequence[str]] | None = None,
         headers: Mapping[str, str | Sequence[str]] | None = None,
         body: Any | None = None,
+        realm: str | None = None,
     ) -> FastMcpToolResult:
         return _call_tool(
             namespace_tool,
@@ -210,6 +213,7 @@ def create_server() -> FastMCP:
                 "query": query,
                 "headers": headers,
                 "body": body,
+                "realm": realm,
             },
             transforms={
                 "namespace": _normalize_namespace,
@@ -231,6 +235,7 @@ def create_server() -> FastMCP:
         query: Mapping[str, str | Sequence[str]] | None = None,
         headers: Mapping[str, str | Sequence[str]] | None = None,
         body: Any | None = None,
+        realm: str | None = None,
     ) -> FastMcpToolResult:
         return _call_tool(
             principal_tool,
@@ -241,6 +246,7 @@ def create_server() -> FastMCP:
                 "query": query,
                 "headers": headers,
                 "body": body,
+                "realm": realm,
             },
             transforms={
                 "query": _copy_mapping,
@@ -262,6 +268,7 @@ def create_server() -> FastMCP:
         query: Mapping[str, str | Sequence[str]] | None = None,
         headers: Mapping[str, str | Sequence[str]] | None = None,
         body: Any | None = None,
+        realm: str | None = None,
     ) -> FastMcpToolResult:
         return _call_tool(
             principal_role_tool,
@@ -273,6 +280,7 @@ def create_server() -> FastMCP:
                 "query": query,
                 "headers": headers,
                 "body": body,
+                "realm": realm,
             },
             transforms={
                 "query": _copy_mapping,
@@ -293,6 +301,7 @@ def create_server() -> FastMCP:
         query: Mapping[str, str | Sequence[str]] | None = None,
         headers: Mapping[str, str | Sequence[str]] | None = None,
         body: Any | None = None,
+        realm: str | None = None,
     ) -> FastMcpToolResult:
         return _call_tool(
             catalog_role_tool,
@@ -305,6 +314,7 @@ def create_server() -> FastMCP:
                 "query": query,
                 "headers": headers,
                 "body": body,
+                "realm": realm,
             },
             transforms={
                 "query": _copy_mapping,
@@ -326,6 +336,7 @@ def create_server() -> FastMCP:
         query: Mapping[str, str | Sequence[str]] | None = None,
         headers: Mapping[str, str | Sequence[str]] | None = None,
         body: Any | None = None,
+        realm: str | None = None,
     ) -> FastMcpToolResult:
         return _call_tool(
             policy_tool,
@@ -339,6 +350,7 @@ def create_server() -> FastMCP:
                 "query": query,
                 "headers": headers,
                 "body": body,
+                "realm": realm,
             },
             transforms={
                 "namespace": _normalize_namespace,
@@ -359,6 +371,7 @@ def create_server() -> FastMCP:
         query: Mapping[str, str | Sequence[str]] | None = None,
         headers: Mapping[str, str | Sequence[str]] | None = None,
         body: Any | None = None,
+        realm: str | None = None,
     ) -> FastMcpToolResult:
         return _call_tool(
             catalog_tool,
@@ -368,6 +381,7 @@ def create_server() -> FastMCP:
                 "query": query,
                 "headers": headers,
                 "body": body,
+                "realm": realm,
             },
             transforms={
                 "query": _copy_mapping,
@@ -482,23 +496,21 @@ def _resolve_http_timeout() -> urllib3.Timeout:
 
 
 def _resolve_authorization_provider(
-    base_url: str, http: urllib3.PoolManager, timeout: urllib3.Timeout
+    base_url: str,
+    http: urllib3.PoolManager,
+    timeout: urllib3.Timeout,
 ) -> AuthorizationProvider:
     token = _resolve_token()
     if token:
         return StaticAuthorizationProvider(token)
 
-    client_id = _first_non_blank(
-        os.getenv("POLARIS_CLIENT_ID"),
-    )
-    client_secret = _first_non_blank(
-        os.getenv("POLARIS_CLIENT_SECRET"),
+    client_id = _first_non_blank(os.getenv("POLARIS_CLIENT_ID"))
+    client_secret = _first_non_blank(os.getenv("POLARIS_CLIENT_SECRET"))
+    has_realm_credentials = any(
+        key.startswith("POLARIS_REALM_") for key in os.environ.keys()
     )
 
-    if client_id and client_secret:
-        scope = _first_non_blank(os.getenv("POLARIS_TOKEN_SCOPE"))
-        token_url = _first_non_blank(os.getenv("POLARIS_TOKEN_URL"))
-        endpoint = token_url or urljoin(base_url, 
"api/catalog/v1/oauth/tokens")
+    if client_id and client_secret or has_realm_credentials:
         refresh_buffer_seconds = DEFAULT_TOKEN_REFRESH_BUFFER_SECONDS
         refresh_buffer_seconds_str = 
os.getenv("POLARIS_TOKEN_REFRESH_BUFFER_SECONDS")
         if refresh_buffer_seconds_str:
@@ -507,10 +519,7 @@ def _resolve_authorization_provider(
             except ValueError:
                 pass
         return ClientCredentialsAuthorizationProvider(
-            token_endpoint=endpoint,
-            client_id=client_id,
-            client_secret=client_secret,
-            scope=scope,
+            base_url=base_url,
             http=http,
             refresh_buffer_seconds=refresh_buffer_seconds,
             timeout=timeout,
diff --git a/mcp-server/polaris_mcp/tools/catalog.py 
b/mcp-server/polaris_mcp/tools/catalog.py
index cdbd909..5c0dc8f 100644
--- a/mcp-server/polaris_mcp/tools/catalog.py
+++ b/mcp-server/polaris_mcp/tools/catalog.py
@@ -106,6 +106,10 @@ class PolarisCatalogTool(McpTool):
         copy_if_object(arguments.get("query"), delegate_args, "query")
         copy_if_object(arguments.get("headers"), delegate_args, "headers")
 
+        realm = arguments.get("realm")
+        if isinstance(realm, str) and realm.strip():
+            delegate_args["realm"] = realm
+
         if normalized == "list":
             delegate_args["method"] = "GET"
             delegate_args["path"] = "catalogs"
diff --git a/mcp-server/polaris_mcp/tools/catalog_role.py 
b/mcp-server/polaris_mcp/tools/catalog_role.py
index eeb0111..8d13ec2 100644
--- a/mcp-server/polaris_mcp/tools/catalog_role.py
+++ b/mcp-server/polaris_mcp/tools/catalog_role.py
@@ -127,6 +127,10 @@ class PolarisCatalogRoleTool(McpTool):
         copy_if_object(arguments.get("query"), delegate_args, "query")
         copy_if_object(arguments.get("headers"), delegate_args, "headers")
 
+        realm = arguments.get("realm")
+        if isinstance(realm, str) and realm.strip():
+            delegate_args["realm"] = realm
+
         base_path = f"catalogs/{catalog}/catalog-roles"
 
         if normalized == "list":
diff --git a/mcp-server/polaris_mcp/tools/namespace.py 
b/mcp-server/polaris_mcp/tools/namespace.py
index 02f312d..7d34070 100644
--- a/mcp-server/polaris_mcp/tools/namespace.py
+++ b/mcp-server/polaris_mcp/tools/namespace.py
@@ -132,6 +132,10 @@ class PolarisNamespaceTool(McpTool):
         copy_if_object(arguments.get("query"), delegate_args, "query")
         copy_if_object(arguments.get("headers"), delegate_args, "headers")
 
+        realm = arguments.get("realm")
+        if isinstance(realm, str) and realm.strip():
+            delegate_args["realm"] = realm
+
         if normalized == "list":
             self._handle_list(delegate_args, catalog)
         elif normalized == "get":
diff --git a/mcp-server/polaris_mcp/tools/policy.py 
b/mcp-server/polaris_mcp/tools/policy.py
index 5463aa8..f8eee13 100644
--- a/mcp-server/polaris_mcp/tools/policy.py
+++ b/mcp-server/polaris_mcp/tools/policy.py
@@ -143,6 +143,10 @@ class PolarisPolicyTool(McpTool):
         copy_if_object(arguments.get("query"), delegate_args, "query")
         copy_if_object(arguments.get("headers"), delegate_args, "headers")
 
+        realm = arguments.get("realm")
+        if isinstance(realm, str) and realm.strip():
+            delegate_args["realm"] = realm
+
         if normalized == "list":
             self._require_namespace(namespace, "list")
             self._handle_list(delegate_args, catalog, namespace)
diff --git a/mcp-server/polaris_mcp/tools/principal.py 
b/mcp-server/polaris_mcp/tools/principal.py
index 60470fb..9bea911 100644
--- a/mcp-server/polaris_mcp/tools/principal.py
+++ b/mcp-server/polaris_mcp/tools/principal.py
@@ -129,6 +129,10 @@ class PolarisPrincipalTool(McpTool):
         copy_if_object(arguments.get("query"), delegate_args, "query")
         copy_if_object(arguments.get("headers"), delegate_args, "headers")
 
+        realm = arguments.get("realm")
+        if isinstance(realm, str) and realm.strip():
+            delegate_args["realm"] = realm
+
         if normalized == "list":
             self._handle_list(delegate_args)
         elif normalized == "create":
diff --git a/mcp-server/polaris_mcp/tools/principal_role.py 
b/mcp-server/polaris_mcp/tools/principal_role.py
index 2941769..b23cce8 100644
--- a/mcp-server/polaris_mcp/tools/principal_role.py
+++ b/mcp-server/polaris_mcp/tools/principal_role.py
@@ -135,6 +135,10 @@ class PolarisPrincipalRoleTool(McpTool):
         copy_if_object(arguments.get("query"), delegate_args, "query")
         copy_if_object(arguments.get("headers"), delegate_args, "headers")
 
+        realm = arguments.get("realm")
+        if isinstance(realm, str) and realm.strip():
+            delegate_args["realm"] = realm
+
         if normalized == "list":
             delegate_args["method"] = "GET"
             delegate_args["path"] = "principal-roles"
diff --git a/mcp-server/polaris_mcp/tools/table.py 
b/mcp-server/polaris_mcp/tools/table.py
index a92a40b..d021a2f 100644
--- a/mcp-server/polaris_mcp/tools/table.py
+++ b/mcp-server/polaris_mcp/tools/table.py
@@ -123,6 +123,10 @@ class PolarisTableTool(McpTool):
         copy_if_object(arguments.get("query"), delegate_args, "query")
         copy_if_object(arguments.get("headers"), delegate_args, "headers")
 
+        realm = arguments.get("realm")
+        if isinstance(realm, str) and realm.strip():
+            delegate_args["realm"] = realm
+
         if normalized == "list":
             self._handle_list(delegate_args, catalog, namespace)
         elif normalized == "get":
diff --git a/mcp-server/tests/test_authorization.py 
b/mcp-server/tests/test_authorization.py
index dfb0945..3c89d32 100644
--- a/mcp-server/tests/test_authorization.py
+++ b/mcp-server/tests/test_authorization.py
@@ -51,6 +51,9 @@ def test_none_authorization_provider_returns_none() -> None:
 def test_client_credentials_fetches_and_caches_tokens(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
+    monkeypatch.setenv("POLARIS_CLIENT_ID", "client")
+    monkeypatch.setenv("POLARIS_CLIENT_SECRET", "secret")
+    monkeypatch.setenv("POLARIS_TOKEN_URL", "https://auth/token")
     http = mock.Mock()
     now = time.time()
     response = SimpleNamespace(
@@ -60,12 +63,9 @@ def test_client_credentials_fetches_and_caches_tokens(
     http.request.return_value = response
 
     provider = ClientCredentialsAuthorizationProvider(
-        token_endpoint="https://auth/token",
-        client_id="client",
-        client_secret="secret",
-        scope=None,
+        base_url="https://polaris/",
         http=http,
-        refresh_buffer_seconds=0.0,
+        refresh_buffer_seconds=60.0,
         timeout=mock.sentinel.timeout,
     )
 
@@ -77,7 +77,9 @@ def test_client_credentials_fetches_and_caches_tokens(
     assert header2 == "Bearer abc"
 
     http.request.assert_called_once()
-    body = http.request.call_args.kwargs["body"]
+    args, kwargs = http.request.call_args
+    assert args[1] == "https://auth/token"
+    body = kwargs["body"]
     assert "grant_type=client_credentials" in body
     assert "client_id=client" in body
     assert "client_secret=secret" in body
@@ -96,7 +98,44 @@ def test_client_credentials_fetches_and_caches_tokens(
     http.request.assert_called_once()
 
 
-def test_client_credentials_refresh_buffer() -> None:
+def test_client_credentials_fetches_and_caches_realm_specific_token(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    http_mock = mock.Mock()
+    provider = ClientCredentialsAuthorizationProvider(
+        base_url="https://polaris/",
+        http=http_mock,
+        refresh_buffer_seconds=60.0,
+        timeout=mock.sentinel.timeout,
+    )
+    monkeypatch.setenv("POLARIS_REALM_TEST_REALM_CLIENT_ID", "realm_client")
+    monkeypatch.setenv("POLARIS_REALM_TEST_REALM_CLIENT_SECRET", 
"realm_secret")
+    monkeypatch.setenv("POLARIS_REALM_TEST_REALM_TOKEN_URL", "https://realm-auth/token")
+
+    now = time.time()
+    response = SimpleNamespace(
+        status=200,
+        data=json.dumps({"access_token": "realm_token", "expires_in": 
3600}).encode(
+            "utf-8"
+        ),
+    )
+    http_mock.request.return_value = response
+
+    with mock.patch("time.time", return_value=now):
+        header = provider.authorization_header(realm="TEST_REALM")
+
+    assert header == "Bearer realm_token"
+    http_mock.request.assert_called_once()
+    args, kwargs = http_mock.request.call_args
+    assert args[1] == "https://realm-auth/token"
+    assert "client_id=realm_client" in kwargs["body"]
+    assert "Polaris-Realm" in kwargs["headers"]
+    assert kwargs["headers"]["Polaris-Realm"] == "TEST_REALM"
+
+
+def test_client_credentials_refresh_buffer(monkeypatch: pytest.MonkeyPatch) -> 
None:
+    monkeypatch.setenv("POLARIS_CLIENT_ID", "client")
+    monkeypatch.setenv("POLARIS_CLIENT_SECRET", "secret")
     http = mock.Mock()
     now = time.time()
     expires_in = 120
@@ -111,10 +150,7 @@ def test_client_credentials_refresh_buffer() -> None:
     http.request.return_value = response
 
     provider = ClientCredentialsAuthorizationProvider(
-        token_endpoint="https://auth/token";,
-        client_id="client",
-        client_secret="secret",
-        scope=None,
+        base_url="https://polaris/",
         http=http,
         refresh_buffer_seconds=refresh_buffer,
         timeout=mock.sentinel.timeout,
@@ -154,6 +190,18 @@ def test_client_credentials_refresh_buffer() -> None:
     http.request.assert_not_called()
 
 
+def test_client_credentials_returns_none_if_no_credentials() -> None:
+    http_mock = mock.Mock()
+    provider = ClientCredentialsAuthorizationProvider(
+        base_url="https://polaris/",
+        http=http_mock,
+        refresh_buffer_seconds=60.0,
+        timeout=mock.sentinel.timeout,
+    )
+    assert provider.authorization_header() is None
+    assert provider.authorization_header(realm="foo") is None
+
+
 @pytest.mark.parametrize(
     "payload,expected_message",
     [
@@ -163,8 +211,10 @@ def test_client_credentials_refresh_buffer() -> None:
     ],
 )
 def test_client_credentials_rejects_invalid_responses(
-    payload: object, expected_message: str
+    payload: object, expected_message: str, monkeypatch: pytest.MonkeyPatch
 ) -> None:
+    monkeypatch.setenv("POLARIS_CLIENT_ID", "client")
+    monkeypatch.setenv("POLARIS_CLIENT_SECRET", "secret")
     http = mock.Mock()
     if isinstance(payload, str):
         data = payload.encode("utf-8")
@@ -173,10 +223,7 @@ def test_client_credentials_rejects_invalid_responses(
     http.request.return_value = SimpleNamespace(status=200, data=data)
 
     provider = ClientCredentialsAuthorizationProvider(
-        token_endpoint="https://auth/token",
-        client_id="client",
-        client_secret="secret",
-        scope=None,
+        base_url="https://polaris/",
         http=http,
         refresh_buffer_seconds=0.0,
         timeout=mock.sentinel.timeout,
@@ -186,15 +233,16 @@ def test_client_credentials_rejects_invalid_responses(
         provider.authorization_header()
 
 
-def test_client_credentials_errors_on_non_200_status() -> None:
+def test_client_credentials_errors_on_non_200_status(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("POLARIS_CLIENT_ID", "client")
+    monkeypatch.setenv("POLARIS_CLIENT_SECRET", "secret")
     http = mock.Mock()
     http.request.return_value = SimpleNamespace(status=500, data=b"boom")
 
     provider = ClientCredentialsAuthorizationProvider(
-        token_endpoint="https://auth/token",
-        client_id="client",
-        client_secret="secret",
-        scope=None,
+        base_url="https://polaris/",
         http=http,
         refresh_buffer_seconds=0.0,
         timeout=mock.sentinel.timeout,
@@ -202,3 +250,123 @@ def test_client_credentials_errors_on_non_200_status() -> None:
 
     with pytest.raises(RuntimeError, match="500"):
         provider.authorization_header()
+
+
+def test_client_credentials_caches_tokens_separately_for_each_realm(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    http_mock = mock.Mock()
+    provider = ClientCredentialsAuthorizationProvider(
+        base_url="https://polaris/",
+        http=http_mock,
+        refresh_buffer_seconds=60.0,
+        timeout=mock.sentinel.timeout,
+    )
+    # Global creds
+    monkeypatch.setenv("POLARIS_CLIENT_ID", "global_client")
+    monkeypatch.setenv("POLARIS_CLIENT_SECRET", "global_secret")
+    # Realm creds
+    monkeypatch.setenv("POLARIS_REALM_realm1_CLIENT_ID", "realm1_client")
+    monkeypatch.setenv("POLARIS_REALM_realm1_CLIENT_SECRET", "realm1_secret")
+
+    # First call for global
+    http_mock.request.return_value = SimpleNamespace(
+        status=200,
+        data=json.dumps({"access_token": "global_token"}).encode("utf-8"),
+    )
+    assert provider.authorization_header() == "Bearer global_token"
+    http_mock.request.assert_called_once()
+    assert "client_id=global_client" in http_mock.request.call_args.kwargs["body"]
+
+    # First call for realm1
+    http_mock.request.return_value = SimpleNamespace(
+        status=200,
+        data=json.dumps({"access_token": "realm1_token"}).encode("utf-8"),
+    )
+    assert provider.authorization_header(realm="realm1") == "Bearer realm1_token"
+    assert http_mock.request.call_count == 2
+    assert "client_id=realm1_client" in http_mock.request.call_args.kwargs["body"]
+
+    # Second call for global should hit cache
+    assert provider.authorization_header() == "Bearer global_token"
+    assert http_mock.request.call_count == 2
+
+    # Second call for realm1 should hit cache
+    assert provider.authorization_header(realm="realm1") == "Bearer realm1_token"
+    assert http_mock.request.call_count == 2
+
+
+def test_with_realm_header(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    realm_name = "TEST_REALM"
+    http = mock.Mock()
+    http.request.return_value = SimpleNamespace(
+        status=200,
+        data=json.dumps({"access_token": "token", "expires_in": 3600}).encode("utf-8"),
+    )
+    monkeypatch.setenv(f"POLARIS_REALM_{realm_name}_CLIENT_ID", "client")
+    monkeypatch.setenv(f"POLARIS_REALM_{realm_name}_CLIENT_SECRET", "secret")
+    provider = ClientCredentialsAuthorizationProvider(
+        base_url="https://polaris/",
+        http=http,
+        refresh_buffer_seconds=60.0,
+        timeout=mock.sentinel.timeout,
+    )
+    provider.authorization_header(realm=realm_name)
+    call_args = http.request.call_args
+    headers = call_args[1]["headers"]
+    assert headers["Polaris-Realm"] == realm_name
+
+
+def test_with_custom_realm_header(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    realm_name = "TEST_REALM"
+    monkeypatch.setenv("POLARIS_REALM_CONTEXT_HEADER_NAME", "X-Polaris-Realm")
+    http = mock.Mock()
+    http.request.return_value = SimpleNamespace(
+        status=200,
+        data=json.dumps({"access_token": "token", "expires_in": 3600}).encode("utf-8"),
+    )
+    monkeypatch.setenv(f"POLARIS_REALM_{realm_name}_CLIENT_ID", "client")
+    monkeypatch.setenv(f"POLARIS_REALM_{realm_name}_CLIENT_SECRET", "secret")
+    provider = ClientCredentialsAuthorizationProvider(
+        base_url="https://polaris/",
+        http=http,
+        refresh_buffer_seconds=60.0,
+        timeout=mock.sentinel.timeout,
+    )
+    provider.authorization_header(realm=realm_name)
+    call_args = http.request.call_args
+    headers = call_args[1]["headers"]
+    assert headers["X-Polaris-Realm"] == realm_name
+
+
+def test_two_realms_one_incomplete(monkeypatch: pytest.MonkeyPatch) -> None:
+    realm1_name = "TEST_REALM"
+    realm2_name = "TEST2_REALM"
+    http = mock.Mock()
+    provider = ClientCredentialsAuthorizationProvider(
+        base_url="https://polaris/",
+        http=http,
+        refresh_buffer_seconds=60.0,
+        timeout=mock.sentinel.timeout,
+    )
+    # Realm 1 – complete credentials
+    monkeypatch.setenv(f"POLARIS_REALM_{realm1_name}_CLIENT_ID", "client")
+    monkeypatch.setenv(f"POLARIS_REALM_{realm1_name}_CLIENT_SECRET", "secret")
+    # Realm 2 – missing secret
+    monkeypatch.setenv(f"POLARIS_REALM_{realm2_name}_CLIENT_ID", "client2")
+    # Mock response for realm 1
+    http.request.return_value = SimpleNamespace(
+        status=200,
+        data=json.dumps({"access_token": "token", "expires_in": 3600}).encode("utf-8"),
+    )
+    # Realm 1 should succeed
+    assert provider.authorization_header(realm=f"{realm1_name}") == "Bearer token"
+    assert http.request.call_count == 1
+    # Realm 2 should return None and not trigger an HTTP request
+    http.request.reset_mock()
+    assert provider.authorization_header(realm=f"{realm2_name}") is None
+    assert http.request.call_count == 0
diff --git a/mcp-server/tests/test_rest_tool.py b/mcp-server/tests/test_rest_tool.py
index cbd3682..07de663 100644
--- a/mcp-server/tests/test_rest_tool.py
+++ b/mcp-server/tests/test_rest_tool.py
@@ -169,3 +169,68 @@ def test_call_requires_non_empty_path() -> None:
         tool.call({"method": "GET"})
 
     http.request.assert_not_called()
+
+
+def test_call_with_realm() -> None:
+    tool, http, auth = _create_tool()
+    http.request.return_value = _build_response(status=200, body="{}")
+    tool.call(
+        {
+            "method": "GET",
+            "path": "namespace",
+            "realm": "realm1",
+        }
+    )
+    auth.authorization_header.assert_called_once_with("realm1")
+    call_args = http.request.call_args
+    headers = call_args[1]["headers"]
+    assert headers["Polaris-Realm"] == "realm1"
+
+
+def test_call_with_existed_realm() -> None:
+    tool, http, auth = _create_tool()
+    http.request.return_value = _build_response(status=200, body="{}")
+    tool.call(
+        {
+            "method": "GET",
+            "path": "namespace",
+            "headers": {"Polaris-Realm": "existing_realm"},
+            "realm": "realm1",
+        }
+    )
+    call_args = http.request.call_args
+    headers = call_args[1]["headers"]
+    assert headers["Polaris-Realm"] == "existing_realm"
+
+
+def test_call_with_custom_realm_header(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("POLARIS_REALM_CONTEXT_HEADER_NAME", "X-Polaris-Realm")
+    tool, http, auth = _create_tool()
+    http.request.return_value = _build_response(status=200, body="{}")
+    tool.call(
+        {
+            "method": "GET",
+            "path": "namespace",
+            "realm": "realm1",
+        }
+    )
+    call_args = http.request.call_args
+    headers = call_args[1]["headers"]
+    assert "Polaris-Realm" not in headers
+    assert headers["X-Polaris-Realm"] == "realm1"
+
+
+def test_call_without_provide_realm() -> None:
+    tool, http, auth = _create_tool()
+    http.request.return_value = _build_response(status=200, body="{}")
+    tool.call(
+        {
+            "method": "GET",
+            "path": "namespace",
+        }
+    )
+    call_args = http.request.call_args
+    headers = call_args[1]["headers"]
+    assert "Polaris-Realm" not in headers
diff --git a/mcp-server/tests/test_server.py b/mcp-server/tests/test_server.py
index 0ba55e0..5336b7d 100644
--- a/mcp-server/tests/test_server.py
+++ b/mcp-server/tests/test_server.py
@@ -242,10 +242,7 @@ class TestAuthorizationProviderResolution:
 
         assert provider is fake_provider
         mock_factory.assert_called_once_with(
-            token_endpoint="https://oauth/token",
-            client_id="client",
-            client_secret="secret",
-            scope="scope",
+            base_url="https://base/",
             http=fake_http,
             refresh_buffer_seconds=60.0,
             timeout=mock.sentinel.timeout,


Reply via email to