This commit is contained in:
Taylor Wilsdon
2026-03-01 17:59:45 -05:00
parent 217d727a9d
commit c9facbff3d

View File

@@ -6,8 +6,9 @@ from importlib import metadata
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.datastructures import MutableHeaders
from starlette.types import Scope, Receive, Send
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response
from starlette.middleware import Middleware from starlette.middleware import Middleware
from fastmcp import FastMCP from fastmcp import FastMCP
@@ -40,78 +41,44 @@ _legacy_callback_registered = False
session_middleware = Middleware(MCPSessionMiddleware) session_middleware = Middleware(MCPSessionMiddleware)
class WellKnownCacheControlMiddleware:
"""Force no-cache headers for OAuth well-known discovery endpoints."""
def __init__(self, app):
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
path = scope.get("path", "")
is_oauth_well_known = path.startswith(
"/.well-known/oauth-authorization-server"
) or path.startswith("/.well-known/oauth-protected-resource")
if not is_oauth_well_known:
await self.app(scope, receive, send)
return
async def send_with_no_cache_headers(message):
if message["type"] == "http.response.start":
headers = MutableHeaders(raw=message.setdefault("headers", []))
headers["Cache-Control"] = "no-store, must-revalidate"
headers["ETag"] = f'"{_compute_scope_fingerprint()}"'
await send(message)
await self.app(scope, receive, send_with_no_cache_headers)
well_known_cache_control_middleware = Middleware(WellKnownCacheControlMiddleware)
def _compute_scope_fingerprint() -> str: def _compute_scope_fingerprint() -> str:
"""Compute a short hash of the current scope configuration for cache-busting.""" """Compute a short hash of the current scope configuration for cache-busting."""
scopes_str = ",".join(sorted(get_current_scopes())) scopes_str = ",".join(sorted(get_current_scopes()))
return hashlib.sha256(scopes_str.encode()).hexdigest()[:12] return hashlib.sha256(scopes_str.encode()).hexdigest()[:12]
def _wrap_well_known_endpoint(endpoint, etag: str):
"""Wrap a well-known metadata endpoint to prevent browser caching.
The MCP SDK hardcodes ``Cache-Control: public, max-age=3600`` on discovery
responses. When the server restarts with different ``--permissions`` or
``--read-only`` flags, browsers / MCP clients serve stale metadata that
advertises the wrong scopes, causing OAuth to silently fail.
The wrapper overrides the header to ``no-store`` and adds an ``ETag``
derived from the current scope set so intermediary caches that ignore
``no-store`` still see a fingerprint change.
Handles both regular request handlers (``async def handler(request)``)
and ASGI app endpoints (e.g. ``CORSMiddleware``) which expect
``(scope, receive, send)`` instead.
"""
import functools
import inspect
# Determine if endpoint is a regular handler function or an ASGI app.
# Starlette's Route uses the same check (see starlette/routing.py).
endpoint_handler = endpoint
while isinstance(endpoint_handler, functools.partial):
endpoint_handler = endpoint_handler.func
is_regular_handler = inspect.isfunction(endpoint_handler) or inspect.ismethod(
endpoint_handler
)
if is_regular_handler:
async def _no_cache_endpoint(request: Request) -> Response:
response = await endpoint(request)
response.headers["Cache-Control"] = "no-store, must-revalidate"
response.headers["ETag"] = etag
return response
return _no_cache_endpoint
else:
# ASGI app (e.g. CORSMiddleware wrapping a handler).
# Invoke via the ASGI interface and capture the response so we can
# override cache headers before returning.
async def _no_cache_asgi_endpoint(request: Request) -> Response:
status_code = 200
raw_headers: list[tuple[bytes, bytes]] = []
body_parts: list[bytes] = []
async def send(message):
nonlocal status_code, raw_headers
if message["type"] == "http.response.start":
status_code = message["status"]
raw_headers = list(message.get("headers", []))
elif message["type"] == "http.response.body":
body_parts.append(message.get("body", b""))
await endpoint(request.scope, request.receive, send)
response = Response(content=b"".join(body_parts), status_code=status_code)
if raw_headers:
response.raw_headers = raw_headers
response.headers["Cache-Control"] = "no-store, must-revalidate"
response.headers["ETag"] = etag
return response
return _no_cache_asgi_endpoint
# Custom FastMCP that adds secure middleware stack for OAuth 2.1 # Custom FastMCP that adds secure middleware stack for OAuth 2.1
class SecureFastMCP(FastMCP): class SecureFastMCP(FastMCP):
def http_app(self, **kwargs) -> "Starlette": def http_app(self, **kwargs) -> "Starlette":
@@ -119,12 +86,16 @@ class SecureFastMCP(FastMCP):
app = super().http_app(**kwargs) app = super().http_app(**kwargs)
# Add middleware in order (first added = outermost layer) # Add middleware in order (first added = outermost layer)
app.user_middleware.insert(0, well_known_cache_control_middleware)
# Session Management - extracts session info for MCP context # Session Management - extracts session info for MCP context
app.user_middleware.insert(0, session_middleware) app.user_middleware.insert(1, session_middleware)
# Rebuild middleware stack # Rebuild middleware stack
app.middleware_stack = app.build_middleware_stack() app.middleware_stack = app.build_middleware_stack()
logger.info("Added middleware stack: Session Management") logger.info(
"Added middleware stack: WellKnownCacheControl, Session Management"
)
return app return app
@@ -461,22 +432,6 @@ def configure_server_for_http():
"OAuth 2.1 enabled using FastMCP GoogleProvider with protocol-level auth" "OAuth 2.1 enabled using FastMCP GoogleProvider with protocol-level auth"
) )
# Mount well-known routes with cache-busting headers.
# The MCP SDK hardcodes Cache-Control: public, max-age=3600
# on discovery responses which causes stale-scope bugs when
# the server is restarted with a different --permissions config.
try:
scope_etag = f'"{_compute_scope_fingerprint()}"'
well_known_routes = provider.get_well_known_routes()
for route in well_known_routes:
logger.info(f"Mounting OAuth well-known route: {route.path}")
wrapped = _wrap_well_known_endpoint(route.endpoint, scope_etag)
server.custom_route(route.path, methods=list(route.methods))(
wrapped
)
except Exception as e:
logger.warning(f"Could not mount well-known routes: {e}")
# Always set auth provider for token validation in middleware # Always set auth provider for token validation in middleware
set_auth_provider(provider) set_auth_provider(provider)
_auth_provider = provider _auth_provider = provider