This is an automated email from the ASF dual-hosted git repository.

yuqi1129 pushed a commit to branch feat/mcp-governance-task3-6
in repository https://gitbox.apache.org/repos/asf/gravitino.git

commit e135d011d20ab2717e6ddf9f0d40f49a0c518409
Author: yuqi <[email protected]>
AuthorDate: Wed Jun 10 20:30:52 2026 +0800

    [#11568] feat(mcp-server): structured per-tool-call audit logging attributed to principal
    
    - core/audit.py: module-level emit() writing one JSON record per call to
      the gravitino.mcp.audit logger.
      Fields: timestamp, principal (bearer:<8-char-prefix> or anonymous), tool,
      outcome, error_type (deny records only).
    - server.py: AuditMiddleware overrides on_call_tool; extracts the principal
      from the HTTP Authorization header via get_http_request() (falls back to
      anonymous in stdio mode); emits allow on success and deny on exception.
    - main.py: configure a gravitino-mcp-audit.log file handler for the audit
      logger.
    - tests/unit/test_audit.py: 9 tests covering emit structure, principal
      extraction, and middleware allow/deny integration.
---
 mcp-server/mcp_server/core/audit.py |  64 ++++++++++++
 mcp-server/mcp_server/main.py       |   6 ++
 mcp-server/mcp_server/server.py     |  50 +++++++++-
 mcp-server/tests/unit/test_audit.py | 189 ++++++++++++++++++++++++++++++++++++
 4 files changed, 308 insertions(+), 1 deletion(-)

diff --git a/mcp-server/mcp_server/core/audit.py b/mcp-server/mcp_server/core/audit.py
new file mode 100644
index 0000000000..f65b71cdff
--- /dev/null
+++ b/mcp-server/mcp_server/core/audit.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import logging
+from datetime import datetime, timezone
+
+_audit_logger = logging.getLogger("gravitino.mcp.audit")
+
+
+def _extract_principal(authorization: "str | None") -> str:
+    """Derive a display principal from a raw Authorization header value.
+
+    The scheme comparison is case-insensitive and whitespace-separated:
+
+    - "Bearer <token>"                → "bearer:<first-8-chars-of-token>"
+    - empty string / None             → "anonymous"
+    - any other scheme, or malformed  → "anonymous"
+
+    NOTE(review): a raw token prefix is a weak identity. Standard JWTs all
+    start with the same base64url-encoded header (typically "eyJhbGci"), so
+    distinct JWT callers collapse into one principal. For opaque tokens the
+    prefix is part of the secret and ends up in the audit log. A hash-based
+    fingerprint or the token's ``sub`` claim would attribute more reliably.
+    TODO decide.
+    """
+    if not authorization:
+        return "anonymous"
+    parts = authorization.split()
+    if len(parts) == 2 and parts[0].lower() == "bearer":
+        token = parts[1]
+        return f"bearer:{token[:8]}"
+    return "anonymous"
+
+
+def emit(
+    *,
+    principal: str,
+    tool: str,
+    outcome: str,
+    error_type: str = "",
+) -> None:
+    """Write one structured JSON audit record to the audit logger.
+
+    The record is serialized with ``json.dumps`` and logged at INFO level
+    on the ``gravitino.mcp.audit`` logger as a single message.
+
+    Args:
+        principal: Identity derived from the request (e.g. "bearer:abc12345"
+            or "anonymous").
+        tool: MCP tool name that was invoked.
+        outcome: "allow" for calls that completed, "deny" for calls that
+            raised. AuditMiddleware reports every exception as "deny", not
+            only authorization failures.
+        error_type: Exception class name when outcome is "deny". The key is
+            omitted from the record when this is empty.
+    """
+    record = {
+        "timestamp": datetime.now(timezone.utc).isoformat(),
+        "principal": principal,
+        "tool": tool,
+        "outcome": outcome,
+    }
+    # Keep allow records compact: error_type only appears when provided.
+    if error_type:
+        record["error_type"] = error_type
+
+    _audit_logger.info(json.dumps(record))
diff --git a/mcp-server/mcp_server/main.py b/mcp-server/mcp_server/main.py
index b1200b3180..ca60d8930f 100644
--- a/mcp-server/mcp_server/main.py
+++ b/mcp-server/mcp_server/main.py
@@ -46,6 +46,12 @@ def _init_logging(setting: Setting):
         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
         datefmt="%Y-%m-%d %H:%M:%S",
     )
+    # Separate file handler for structured audit records (one JSON line per 
entry).
+    audit_handler = logging.FileHandler("gravitino-mcp-audit.log")
+    audit_handler.setLevel(logging.INFO)
+    audit_handler.setFormatter(logging.Formatter("%(message)s"))
+    logging.getLogger("gravitino.mcp.audit").addHandler(audit_handler)
+    logging.getLogger("gravitino.mcp.audit").propagate = False
 
 
 def _comma_separated_set(value) -> set:
diff --git a/mcp-server/mcp_server/server.py b/mcp-server/mcp_server/server.py
index 9cce5f556b..8d66fec936 100644
--- a/mcp-server/mcp_server/server.py
+++ b/mcp-server/mcp_server/server.py
@@ -16,23 +16,71 @@
 # under the License.
 
 import asyncio
+import contextlib
 import logging
 from contextlib import asynccontextmanager
 from typing import AsyncIterator
 from urllib.parse import urlparse
 
+import mcp.types as mt
 from fastmcp import FastMCP
 from fastmcp.server.middleware.error_handling import ErrorHandlingMiddleware
 from fastmcp.server.middleware.logging import (
     LoggingMiddleware,
 )
+from fastmcp.server.middleware.middleware import (
+    CallNext,
+    Middleware,
+    MiddlewareContext,
+)
 from fastmcp.server.middleware.timing import TimingMiddleware
+from fastmcp.tools.base import ToolResult
 
+from mcp_server.core import audit
 from mcp_server.core.context import GravitinoContext
 from mcp_server.core.setting import Setting
 from mcp_server.tools import load_tools
 
 
+def _get_principal_from_request() -> str:
+    """Return the audit principal for the tool call currently being served.
+
+    Reads the ``Authorization`` header of the active HTTP request and maps it
+    through ``audit._extract_principal``. This is deliberately best-effort:
+    any failure (stdio transport with no HTTP request, no request context,
+    or the import itself failing) yields "anonymous" instead of propagating.
+    """
+    try:
+        from fastmcp.server.dependencies import get_http_request
+
+        header_value = get_http_request().headers.get("authorization", "")
+        return audit._extract_principal(header_value)
+    except Exception:  # noqa: BLE001 – stdio mode or no request context
+        return "anonymous"
+
+
+class AuditMiddleware(Middleware):
+    """Emit a structured audit record for every tool invocation.
+
+    Exactly one record is written per tool call:
+    - "allow" when the downstream chain returns;
+    - "deny", with the exception class name, when it raises.
+    The downstream exception is always re-raised unchanged.
+    """
+
+    async def on_call_tool(
+        self,
+        context: MiddlewareContext[mt.CallToolRequestParams],
+        call_next: CallNext[mt.CallToolRequestParams, ToolResult],
+    ) -> ToolResult:
+        """Run the rest of the middleware chain and audit its outcome.
+
+        Args:
+            context: Middleware context; ``context.message.name`` is the tool.
+            call_next: Continuation that executes the remaining chain.
+
+        Returns:
+            The downstream ``ToolResult``, unchanged.
+
+        Raises:
+            Exception: Whatever the downstream chain raised, re-raised after
+                a "deny" record has been emitted.
+        """
+        tool_name = context.message.name if context.message else "unknown"
+        principal = _get_principal_from_request()
+        try:
+            result = await call_next(context)
+        except Exception as exc:
+            # NOTE(review): every exception is reported as "deny", not only
+            # authorization failures. BaseException-only exits such as
+            # asyncio.CancelledError produce no record at all.
+            audit.emit(
+                principal=principal,
+                tool=tool_name,
+                outcome="deny",
+                error_type=type(exc).__name__,
+            )
+            raise
+        # Emit outside the try so a failure while writing the allow record
+        # cannot be misreported as a denied call and re-raised to the client.
+        audit.emit(principal=principal, tool=tool_name, outcome="allow")
+        return result
+
+
 def _create_lifespan_manager(gravitino_context: GravitinoContext):
 
     @asynccontextmanager
@@ -56,10 +104,10 @@ def _create_gravitino_mcp(setting: Setting) -> FastMCP:
             lifespan=_create_lifespan_manager(GravitinoContext(setting)),
         )
 
+    mcp.add_middleware(AuditMiddleware())
     mcp.add_middleware(
         LoggingMiddleware(include_payloads=True, max_payload_length=1000)
     )
-
     mcp.add_middleware(TimingMiddleware())
     mcp.add_middleware(
         ErrorHandlingMiddleware(
diff --git a/mcp-server/tests/unit/test_audit.py b/mcp-server/tests/unit/test_audit.py
new file mode 100644
index 0000000000..37cdef29cb
--- /dev/null
+++ b/mcp-server/tests/unit/test_audit.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import asyncio
+import json
+import logging
+import unittest
+from unittest.mock import patch
+
+from fastmcp import Client
+
+from mcp_server.client.factory import RESTClientFactory
+from mcp_server.client.plain.exception import GravitinoException
+from mcp_server.client.plain.plain_rest_client_operation import 
PlainRESTClientOperation
+from mcp_server.core import audit
+from mcp_server.core.setting import Setting
+from mcp_server.server import GravitinoMCPServer
+from tests.unit.tools import MockOperation
+
+
+class TestAuditEmit(unittest.TestCase):
+    """Unit tests for the audit.emit() function."""
+
+    def setUp(self):
+        # tests/unit/tools/__init__.py calls logging.disable(logging.INFO).
+        # Re-enable here so audit records flow through; restore on teardown.
+        logging.disable(logging.NOTSET)
+        self.log_records = []
+        self.handler = _CapturingHandler(self.log_records)
+        audit_logger = logging.getLogger("gravitino.mcp.audit")
+        # Save the logger's prior settings so tearDown restores them, rather
+        # than forcing propagate=True and leaving the INFO level behind.
+        self._saved_level = audit_logger.level
+        self._saved_propagate = audit_logger.propagate
+        audit_logger.addHandler(self.handler)
+        audit_logger.setLevel(logging.INFO)
+        audit_logger.propagate = False
+
+    def tearDown(self):
+        logging.disable(logging.INFO)
+        audit_logger = logging.getLogger("gravitino.mcp.audit")
+        audit_logger.removeHandler(self.handler)
+        audit_logger.setLevel(self._saved_level)
+        audit_logger.propagate = self._saved_propagate
+
+    def test_allow_record_structure(self):
+        """emit() writes a JSON record with all required fields on allow."""
+        audit.emit(
+            principal="bearer:abc12345", tool="list_catalogs", outcome="allow"
+        )
+
+        self.assertEqual(len(self.log_records), 1)
+        record = json.loads(self.log_records[0])
+        self.assertEqual(record["principal"], "bearer:abc12345")
+        self.assertEqual(record["tool"], "list_catalogs")
+        self.assertEqual(record["outcome"], "allow")
+        self.assertIn("timestamp", record)
+        self.assertNotIn("error_type", record)
+
+    def test_deny_record_includes_error_type(self):
+        """emit() includes error_type in the record when outcome is deny."""
+        audit.emit(
+            principal="bearer:xyz99999",
+            tool="create_tag",
+            outcome="deny",
+            error_type="GravitinoException",
+        )
+
+        self.assertEqual(len(self.log_records), 1)
+        record = json.loads(self.log_records[0])
+        self.assertEqual(record["outcome"], "deny")
+        self.assertEqual(record["error_type"], "GravitinoException")
+
+    def test_anonymous_principal(self):
+        """emit() works with anonymous principal."""
+        audit.emit(
+            principal="anonymous", tool="get_list_of_catalogs", outcome="allow"
+        )
+
+        self.assertEqual(len(self.log_records), 1)
+        record = json.loads(self.log_records[0])
+        self.assertEqual(record["principal"], "anonymous")
+
+
+class TestExtractPrincipal(unittest.TestCase):
+    """Unit tests for audit._extract_principal() header parsing."""
+
+    def _assert_principal(self, header, expected):
+        # Shared check: parse `header` and compare with the expected principal.
+        self.assertEqual(audit._extract_principal(header), expected)
+
+    def test_bearer_token_truncated_to_8_chars(self):
+        # Tokens longer than 8 characters are reduced to their prefix.
+        self._assert_principal("Bearer abcdefghijklmnop", "bearer:abcdefgh")
+
+    def test_empty_header_returns_anonymous(self):
+        self._assert_principal("", "anonymous")
+
+    def test_none_like_empty_returns_anonymous(self):
+        # A missing header (None) is treated exactly like an empty one.
+        self._assert_principal(None, "anonymous")
+
+    def test_short_token_uses_full_token(self):
+        # Slicing past the end is safe: a short token is kept whole.
+        self._assert_principal("Bearer abc", "bearer:abc")
+
+
+class TestAuditMiddlewareIntegration(unittest.TestCase):
+    """Integration tests: AuditMiddleware records real MCP tool calls."""
+
+    def setUp(self):
+        logging.disable(logging.NOTSET)
+        self.log_records = []
+        self.handler = _CapturingHandler(self.log_records)
+        audit_logger = logging.getLogger("gravitino.mcp.audit")
+        # Save the logger's prior settings so tearDown restores them, rather
+        # than forcing propagate=True and leaving the INFO level behind.
+        self._saved_level = audit_logger.level
+        self._saved_propagate = audit_logger.propagate
+        audit_logger.addHandler(self.handler)
+        audit_logger.setLevel(logging.INFO)
+        audit_logger.propagate = False
+
+        RESTClientFactory.set_rest_client(MockOperation)
+        server = GravitinoMCPServer(Setting("mock_metalake"))
+        self.mcp = server.mcp
+
+    def tearDown(self):
+        logging.disable(logging.INFO)
+        audit_logger = logging.getLogger("gravitino.mcp.audit")
+        audit_logger.removeHandler(self.handler)
+        audit_logger.setLevel(self._saved_level)
+        audit_logger.propagate = self._saved_propagate
+        # Restore original REST client so other tests are not affected.
+        RESTClientFactory.set_rest_client(PlainRESTClientOperation)
+
+    def test_successful_tool_call_emits_allow_record(self):
+        """A successful tool call produces an audit record with outcome=allow."""
+
+        async def _run():
+            async with Client(self.mcp) as client:
+                await client.call_tool("get_list_of_catalogs")
+
+        asyncio.run(_run())
+
+        self.assertEqual(len(self.log_records), 1)
+        record = json.loads(self.log_records[0])
+        self.assertEqual(record["tool"], "get_list_of_catalogs")
+        self.assertEqual(record["outcome"], "allow")
+        self.assertEqual(record["principal"], "anonymous")
+
+    def test_failed_tool_call_emits_deny_record(self):
+        """A tool call that raises produces an audit record with outcome=deny."""
+
+        class FailingOperation(MockOperation):
+            def as_catalog_operation(self):
+                return _FailingCatalogOperation()
+
+        RESTClientFactory.set_rest_client(FailingOperation)
+        server = GravitinoMCPServer(Setting("mock_metalake"))
+
+        async def _run():
+            async with Client(server.mcp) as client:
+                try:
+                    await client.call_tool("get_list_of_catalogs")
+                except Exception:
+                    # The failure may surface as a client-side exception;
+                    # only the audit side effect is under test here.
+                    pass
+
+        asyncio.run(_run())
+
+        # Parse every record and filter on the field itself, instead of
+        # substring-matching '"deny"' anywhere in the raw JSON text.
+        records = [json.loads(raw) for raw in self.log_records]
+        deny_records = [r for r in records if r["outcome"] == "deny"]
+        self.assertGreaterEqual(len(deny_records), 1)
+        self.assertEqual(deny_records[0]["tool"], "get_list_of_catalogs")
+        self.assertIn("error_type", deny_records[0])
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+class _CapturingHandler(logging.Handler):
+    """Collect formatted log messages into a caller-supplied list."""
+
+    def __init__(self, records: list):
+        super().__init__()
+        # Held by reference: the test inspects the same list it passed in.
+        self._records = records
+
+    def emit(self, record: logging.LogRecord) -> None:
+        message = self.format(record)
+        self._records.append(message)
+
+
+class _FailingCatalogOperation:
+    """Catalog-operation stub whose catalog listing always fails."""
+
+    async def get_list_of_catalogs(self) -> str:
+        # Mimics a FORBIDDEN error response from the Gravitino server.
+        message = "Error code: 1003, Error type: FORBIDDEN"
+        raise GravitinoException(message)

Reply via email to