From c9facbff3dd829ad506c97492bd648e6ee8efd6d Mon Sep 17 00:00:00 2001 From: Taylor Wilsdon Date: Sun, 1 Mar 2026 17:59:45 -0500 Subject: [PATCH] refac --- core/server.py | 125 ++++++++++++++++--------------------------------- 1 file changed, 40 insertions(+), 85 deletions(-) diff --git a/core/server.py b/core/server.py index 51541e4..2931b34 100644 --- a/core/server.py +++ b/core/server.py @@ -6,8 +6,9 @@ from importlib import metadata from fastapi.responses import HTMLResponse, JSONResponse, FileResponse from starlette.applications import Starlette +from starlette.datastructures import MutableHeaders +from starlette.types import Scope, Receive, Send from starlette.requests import Request -from starlette.responses import Response from starlette.middleware import Middleware from fastmcp import FastMCP @@ -40,78 +41,44 @@ _legacy_callback_registered = False session_middleware = Middleware(MCPSessionMiddleware) +class WellKnownCacheControlMiddleware: + """Force no-cache headers for OAuth well-known discovery endpoints.""" + + def __init__(self, app): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + path = scope.get("path", "") + is_oauth_well_known = path.startswith( + "/.well-known/oauth-authorization-server" + ) or path.startswith("/.well-known/oauth-protected-resource") + if not is_oauth_well_known: + await self.app(scope, receive, send) + return + + async def send_with_no_cache_headers(message): + if message["type"] == "http.response.start": + headers = MutableHeaders(raw=message.setdefault("headers", [])) + headers["Cache-Control"] = "no-store, must-revalidate" + headers["ETag"] = f'"{_compute_scope_fingerprint()}"' + await send(message) + + await self.app(scope, receive, send_with_no_cache_headers) + + +well_known_cache_control_middleware = Middleware(WellKnownCacheControlMiddleware) + + def _compute_scope_fingerprint() -> str: """Compute a short hash of the current scope configuration for cache-busting.""" scopes_str = ",".join(sorted(get_current_scopes())) return hashlib.sha256(scopes_str.encode()).hexdigest()[:12] -def _wrap_well_known_endpoint(endpoint, etag: str): - """Wrap a well-known metadata endpoint to prevent browser caching. - - The MCP SDK hardcodes ``Cache-Control: public, max-age=3600`` on discovery - responses. When the server restarts with different ``--permissions`` or - ``--read-only`` flags, browsers / MCP clients serve stale metadata that - advertises the wrong scopes, causing OAuth to silently fail. - - The wrapper overrides the header to ``no-store`` and adds an ``ETag`` - derived from the current scope set so intermediary caches that ignore - ``no-store`` still see a fingerprint change. - - Handles both regular request handlers (``async def handler(request)``) - and ASGI app endpoints (e.g. ``CORSMiddleware``) which expect - ``(scope, receive, send)`` instead. - """ - import functools - import inspect - - # Determine if endpoint is a regular handler function or an ASGI app. - # Starlette's Route uses the same check (see starlette/routing.py). - endpoint_handler = endpoint - while isinstance(endpoint_handler, functools.partial): - endpoint_handler = endpoint_handler.func - is_regular_handler = inspect.isfunction(endpoint_handler) or inspect.ismethod( - endpoint_handler - ) - - if is_regular_handler: - - async def _no_cache_endpoint(request: Request) -> Response: - response = await endpoint(request) - response.headers["Cache-Control"] = "no-store, must-revalidate" - response.headers["ETag"] = etag - return response - - return _no_cache_endpoint - else: - # ASGI app (e.g. CORSMiddleware wrapping a handler). - # Invoke via the ASGI interface and capture the response so we can - # override cache headers before returning. - async def _no_cache_asgi_endpoint(request: Request) -> Response: - status_code = 200 - raw_headers: list[tuple[bytes, bytes]] = [] - body_parts: list[bytes] = [] - - async def send(message): - nonlocal status_code, raw_headers - if message["type"] == "http.response.start": - status_code = message["status"] - raw_headers = list(message.get("headers", [])) - elif message["type"] == "http.response.body": - body_parts.append(message.get("body", b"")) - - await endpoint(request.scope, request.receive, send) - - response = Response(content=b"".join(body_parts), status_code=status_code) - if raw_headers: - response.raw_headers = raw_headers - response.headers["Cache-Control"] = "no-store, must-revalidate" - response.headers["ETag"] = etag - return response - - return _no_cache_asgi_endpoint - - # Custom FastMCP that adds secure middleware stack for OAuth 2.1 class SecureFastMCP(FastMCP): def http_app(self, **kwargs) -> "Starlette": @@ -119,12 +86,16 @@ class SecureFastMCP(FastMCP): app = super().http_app(**kwargs) # Add middleware in order (first added = outermost layer) + app.user_middleware.insert(0, well_known_cache_control_middleware) + # Session Management - extracts session info for MCP context - app.user_middleware.insert(0, session_middleware) + app.user_middleware.insert(1, session_middleware) # Rebuild middleware stack app.middleware_stack = app.build_middleware_stack() - logger.info("Added middleware stack: Session Management") + logger.info( + "Added middleware stack: WellKnownCacheControl, Session Management" + ) return app @@ -461,22 +432,6 @@ def configure_server_for_http(): "OAuth 2.1 enabled using FastMCP GoogleProvider with protocol-level auth" ) - # Mount well-known routes with cache-busting headers. - # The MCP SDK hardcodes Cache-Control: public, max-age=3600 - # on discovery responses which causes stale-scope bugs when - # the server is restarted with a different --permissions config. - try: - scope_etag = f'"{_compute_scope_fingerprint()}"' - well_known_routes = provider.get_well_known_routes() - for route in well_known_routes: - logger.info(f"Mounting OAuth well-known route: {route.path}") - wrapped = _wrap_well_known_endpoint(route.endpoint, scope_etag) - server.custom_route(route.path, methods=list(route.methods))( - wrapped - ) - except Exception as e: - logger.warning(f"Could not mount well-known routes: {e}") - # Always set auth provider for token validation in middleware set_auth_provider(provider) _auth_provider = provider