This is an automated email from the ASF dual-hosted git repository. yuqi1129 pushed a commit to branch feat/mcp-governance-task3-6 in repository https://gitbox.apache.org/repos/asf/gravitino.git
commit e135d011d20ab2717e6ddf9f0d40f49a0c518409 Author: yuqi <[email protected]> AuthorDate: Wed Jun 10 20:30:52 2026 +0800 [#11568] feat(mcp-server): structured per-tool-call audit logging attributed to principal - core/audit.py: module-level emit() writing JSON records to the gravitino.mcp.audit logger. Fields: timestamp, principal (bearer:<8-char-prefix> or anonymous), tool, outcome, error_type. - server.py: AuditMiddleware overrides on_call_tool; extracts principal from HTTP Authorization header via get_http_request() (falls back to anonymous in stdio mode); emits allow on success and deny on exception. - main.py: configure gravitino-mcp-audit.log file handler for the audit logger. - tests/unit/test_audit.py: 9 tests covering emit structure, principal extraction, middleware allow/deny integration. --- mcp-server/mcp_server/core/audit.py | 64 ++++++++++++ mcp-server/mcp_server/main.py | 6 ++ mcp-server/mcp_server/server.py | 50 +++++++++- mcp-server/tests/unit/test_audit.py | 189 ++++++++++++++++++++++++++++++++++++ 4 files changed, 308 insertions(+), 1 deletion(-) diff --git a/mcp-server/mcp_server/core/audit.py b/mcp-server/mcp_server/core/audit.py new file mode 100644 index 0000000000..f65b71cdff --- /dev/null +++ b/mcp-server/mcp_server/core/audit.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import logging
+from datetime import datetime, timezone
+
+_audit_logger = logging.getLogger("gravitino.mcp.audit")
+
+
+def _extract_principal(authorization: str) -> str:
+    """Derive a display principal from a raw Authorization header value.
+
+    - "Bearer <token>" → "bearer:<first-8-chars-of-token>"
+    - empty / None / non-Bearer or malformed header → "anonymous"
+
+    Tokens of 8 characters or fewer therefore appear in full in the audit log.
+    NOTE(review): tokens sharing their first 8 chars (e.g. JWTs with identical headers) collide.
+    """
+    parts = (authorization or "").split()
+    if len(parts) == 2 and parts[0].lower() == "bearer":
+        return f"bearer:{parts[1][:8]}"
+    return "anonymous"
+
+
+def emit(
+    *,
+    principal: str,
+    tool: str,
+    outcome: str,
+    error_type: str = "",
+) -> None:
+    """Write one structured JSON audit record to the audit logger.
+
+    Args:
+        principal: Identity derived from the request (e.g. "bearer:abc12345" or "anonymous").
+        tool: MCP tool name that was invoked.
+        outcome: "allow" when the tool call succeeded, "deny" when it raised any exception.
+        error_type: Exception class name for a "deny"; the key is omitted when empty.
+    """
+    record = {
+        "timestamp": datetime.now(timezone.utc).isoformat(),
+        "principal": principal,
+        "tool": tool,
+        "outcome": outcome,
+    }
+    if error_type:
+        record["error_type"] = error_type
+
+    _audit_logger.info(json.dumps(record))
diff --git a/mcp-server/mcp_server/main.py b/mcp-server/mcp_server/main.py
index b1200b3180..ca60d8930f 100644
--- a/mcp-server/mcp_server/main.py
+++ b/mcp-server/mcp_server/main.py
@@ -46,6 +46,12 @@ def _init_logging(setting: Setting):
         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
         datefmt="%Y-%m-%d %H:%M:%S",
     )
+    # Separate file handler for structured audit records (one JSON line per entry).
+    audit_handler = logging.FileHandler("gravitino-mcp-audit.log")
+    audit_handler.setFormatter(logging.Formatter("%(message)s"))
+    logging.getLogger("gravitino.mcp.audit").setLevel(logging.INFO)  # decoupled from root level
+    logging.getLogger("gravitino.mcp.audit").addHandler(audit_handler)
+    logging.getLogger("gravitino.mcp.audit").propagate = False
 
 
 def _comma_separated_set(value) -> set:
diff --git a/mcp-server/mcp_server/server.py b/mcp-server/mcp_server/server.py
index 9cce5f556b..8d66fec936 100644
--- a/mcp-server/mcp_server/server.py
+++ b/mcp-server/mcp_server/server.py
@@ -16,23 +16,71 @@
 # under the License.
 
 import asyncio
 import logging
 from contextlib import asynccontextmanager
 from typing import AsyncIterator
 from urllib.parse import urlparse
 
+import mcp.types as mt
 from fastmcp import FastMCP
 from fastmcp.server.middleware.error_handling import ErrorHandlingMiddleware
 from fastmcp.server.middleware.logging import (
     LoggingMiddleware,
 )
+from fastmcp.server.middleware.middleware import (
+    CallNext,
+    Middleware,
+    MiddlewareContext,
+)
 from fastmcp.server.middleware.timing import TimingMiddleware
+from fastmcp.tools.base import ToolResult
 
+from mcp_server.core import audit
 from mcp_server.core.context import GravitinoContext
 from mcp_server.core.setting import Setting
 from mcp_server.tools import load_tools
 
 
+def _get_principal_from_request() -> str:
+    """Extract the principal from the current HTTP request's Authorization header.
+
+    Returns "anonymous" in stdio mode, without a Bearer token, or if the lookup fails.
+    """
+    try:
+        from fastmcp.server.dependencies import get_http_request
+
+        request = get_http_request()
+        authorization = request.headers.get("authorization", "")
+        return audit._extract_principal(authorization)
+    except Exception:  # noqa: BLE001 – stdio mode or no request context
+        return "anonymous"
+
+
+class AuditMiddleware(Middleware):
+    """Emit a structured audit record for every tool invocation."""
+
+    async def on_call_tool(
+        self,
+        context: MiddlewareContext[mt.CallToolRequestParams],
+        call_next: CallNext[mt.CallToolRequestParams, ToolResult],
+    ) -> ToolResult:
+        tool_name = context.message.name if context.message else "unknown"
+        principal = _get_principal_from_request()
+        try:
+            result = await call_next(context)
+            audit.emit(principal=principal, tool=tool_name, outcome="allow")
+            return result
+        except Exception as exc:
+            # NOTE(review): inner middleware may re-wrap exc; confirm error_type is informative.
+            audit.emit(
+                principal=principal,
+                tool=tool_name,
+                outcome="deny",
+                error_type=type(exc).__name__,
+            )
+            raise
+
+
 def _create_lifespan_manager(gravitino_context: GravitinoContext):
 
     @asynccontextmanager
@@ -56,10 +104,10 @@ def _create_gravitino_mcp(setting: Setting) -> FastMCP:
         lifespan=_create_lifespan_manager(GravitinoContext(setting)),
     )
 
+    mcp.add_middleware(AuditMiddleware())
     mcp.add_middleware(
         LoggingMiddleware(include_payloads=True, max_payload_length=1000)
     )
-
     mcp.add_middleware(TimingMiddleware())
     mcp.add_middleware(
         ErrorHandlingMiddleware(
diff --git a/mcp-server/tests/unit/test_audit.py b/mcp-server/tests/unit/test_audit.py
new file mode 100644
index 0000000000..37cdef29cb
--- /dev/null
+++ b/mcp-server/tests/unit/test_audit.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import asyncio
+import json
+import logging
+import unittest
+
+from fastmcp import Client
+
+from mcp_server.client.factory import RESTClientFactory
+from mcp_server.client.plain.exception import GravitinoException
+from mcp_server.client.plain.plain_rest_client_operation import PlainRESTClientOperation
+from mcp_server.core import audit
+from mcp_server.core.setting import Setting
+from mcp_server.server import GravitinoMCPServer
+from tests.unit.tools import MockOperation
+
+
+class TestAuditEmit(unittest.TestCase):
+    """Unit tests for the audit.emit() function."""
+
+    def setUp(self):
+        # tests/unit/tools/__init__.py calls logging.disable(logging.INFO).
+        # Re-enable here so audit records flow through; restore on teardown.
+        logging.disable(logging.NOTSET)
+        self.log_records = []
+        self.handler = _CapturingHandler(self.log_records)
+        audit_logger = logging.getLogger("gravitino.mcp.audit")
+        audit_logger.addHandler(self.handler)
+        audit_logger.setLevel(logging.INFO)
+        audit_logger.propagate = False
+
+    def tearDown(self):
+        logging.disable(logging.INFO)
+        audit_logger = logging.getLogger("gravitino.mcp.audit")
+        audit_logger.removeHandler(self.handler)
+        audit_logger.propagate = True
+
+    def test_allow_record_structure(self):
+        """emit() writes a JSON record with all required fields on allow."""
+        audit.emit(principal="bearer:abc12345", tool="list_catalogs", outcome="allow")
+
+        self.assertEqual(len(self.log_records), 1)
+        record = json.loads(self.log_records[0])
+        self.assertEqual(record["principal"], "bearer:abc12345")
+        self.assertEqual(record["tool"], "list_catalogs")
+        self.assertEqual(record["outcome"], "allow")
+        self.assertIn("timestamp", record)
+        self.assertNotIn("error_type", record)
+
+    def test_deny_record_includes_error_type(self):
+        """emit() includes error_type in the record when outcome is deny."""
+        audit.emit(
+            principal="bearer:xyz99999",
+            tool="create_tag",
+            outcome="deny",
+            error_type="GravitinoException",
+        )
+
+        self.assertEqual(len(self.log_records), 1)
+        record = json.loads(self.log_records[0])
+        self.assertEqual(record["outcome"], "deny")
+        self.assertEqual(record["error_type"], "GravitinoException")
+
+    def test_anonymous_principal(self):
+        """emit() works with anonymous principal."""
+        audit.emit(principal="anonymous", tool="get_list_of_catalogs", outcome="allow")
+        record = json.loads(self.log_records[0])
+        self.assertEqual(record["principal"], "anonymous")
+
+
+class TestExtractPrincipal(unittest.TestCase):
+    """Unit tests for audit._extract_principal()."""
+
+    def test_bearer_token_truncated_to_8_chars(self):
+        self.assertEqual(
+            audit._extract_principal("Bearer abcdefghijklmnop"),
+            "bearer:abcdefgh",
+        )
+
+    def test_empty_header_returns_anonymous(self):
+        self.assertEqual(audit._extract_principal(""), "anonymous")
+
+    def test_none_like_empty_returns_anonymous(self):
+        self.assertEqual(audit._extract_principal(None), "anonymous")
+
+    def test_short_token_uses_full_token(self):
+        self.assertEqual(audit._extract_principal("Bearer abc"), "bearer:abc")
+
+
+class TestAuditMiddlewareIntegration(unittest.TestCase):
+    """Integration tests: AuditMiddleware emits records via the full MCP tool path."""
+
+    def setUp(self):
+        logging.disable(logging.NOTSET)
+        self.log_records = []
+        self.handler = _CapturingHandler(self.log_records)
+        audit_logger = logging.getLogger("gravitino.mcp.audit")
+        audit_logger.addHandler(self.handler)
+        audit_logger.setLevel(logging.INFO)
+        audit_logger.propagate = False
+
+        RESTClientFactory.set_rest_client(MockOperation)
+        server = GravitinoMCPServer(Setting("mock_metalake"))
+        self.mcp = server.mcp
+
+    def tearDown(self):
+        logging.disable(logging.INFO)
+        audit_logger = logging.getLogger("gravitino.mcp.audit")
+        audit_logger.removeHandler(self.handler)
+        audit_logger.propagate = True
+        # Restore original REST client so other tests are not affected.
+        RESTClientFactory.set_rest_client(PlainRESTClientOperation)
+
+    def test_successful_tool_call_emits_allow_record(self):
+        """A successful tool call produces an audit record with outcome=allow."""
+
+        async def _run():
+            async with Client(self.mcp) as client:
+                await client.call_tool("get_list_of_catalogs")
+
+        asyncio.run(_run())
+
+        self.assertEqual(len(self.log_records), 1)
+        record = json.loads(self.log_records[0])
+        self.assertEqual(record["tool"], "get_list_of_catalogs")
+        self.assertEqual(record["outcome"], "allow")
+        self.assertEqual(record["principal"], "anonymous")
+
+    def test_failed_tool_call_emits_deny_record(self):
+        """A tool call that raises an exception produces an audit record with outcome=deny."""
+
+        class FailingOperation(MockOperation):
+            def as_catalog_operation(self):
+                return _FailingCatalogOperation()
+
+        RESTClientFactory.set_rest_client(FailingOperation)
+        server = GravitinoMCPServer(Setting("mock_metalake"))
+
+        async def _run():
+            async with Client(server.mcp) as client:
+                try:
+                    await client.call_tool("get_list_of_catalogs")
+                except Exception:
+                    pass
+
+        asyncio.run(_run())
+
+        records = [json.loads(r) for r in self.log_records]
+        deny_records = [r for r in records if r["outcome"] == "deny"]
+        self.assertGreaterEqual(len(deny_records), 1)
+        self.assertIn("error_type", deny_records[0])
+        self.assertEqual(deny_records[0]["tool"], "get_list_of_catalogs")
+        self.assertEqual(deny_records[0]["outcome"], "deny")
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+class _CapturingHandler(logging.Handler):
+    """Logging handler that stores formatted messages in a list."""
+
+    def __init__(self, records: list):
+        super().__init__()
+        self._records = records
+
+    def emit(self, record: logging.LogRecord) -> None:
+        self._records.append(self.format(record))
+
+
+class _FailingCatalogOperation:
+    async def get_list_of_catalogs(self) -> str:
+        raise GravitinoException("Error code: 1003, Error type: FORBIDDEN")
