diff --git a/core/server.py b/core/server.py index d21c9d5..15a55e0 100644 --- a/core/server.py +++ b/core/server.py @@ -6,8 +6,9 @@ from importlib import metadata from fastapi.responses import HTMLResponse, JSONResponse, FileResponse from starlette.applications import Starlette +from starlette.datastructures import MutableHeaders +from starlette.types import Scope, Receive, Send from starlette.requests import Request -from starlette.responses import Response from starlette.middleware import Middleware from fastmcp import FastMCP @@ -40,34 +41,47 @@ _legacy_callback_registered = False session_middleware = Middleware(MCPSessionMiddleware) +class WellKnownCacheControlMiddleware: + """Force no-cache headers for OAuth well-known discovery endpoints.""" + + def __init__(self, app): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + path = scope.get("path", "") + is_oauth_well_known = ( + path == "/.well-known/oauth-authorization-server" + or path.startswith("/.well-known/oauth-authorization-server/") + or path == "/.well-known/oauth-protected-resource" + or path.startswith("/.well-known/oauth-protected-resource/") + ) + if not is_oauth_well_known: + await self.app(scope, receive, send) + return + + async def send_with_no_cache_headers(message): + if message["type"] == "http.response.start": + headers = MutableHeaders(raw=message.setdefault("headers", [])) + headers["Cache-Control"] = "no-store, must-revalidate" + headers["ETag"] = f'"{_compute_scope_fingerprint()}"' + await send(message) + + await self.app(scope, receive, send_with_no_cache_headers) + + +well_known_cache_control_middleware = Middleware(WellKnownCacheControlMiddleware) + + def _compute_scope_fingerprint() -> str: """Compute a short hash of the current scope configuration for cache-busting.""" scopes_str = ",".join(sorted(get_current_scopes())) return hashlib.sha256(scopes_str.encode()).hexdigest()[:12] -def _wrap_well_known_endpoint(endpoint, etag: str): - """Wrap a well-known metadata endpoint to prevent browser caching. - - The MCP SDK hardcodes ``Cache-Control: public, max-age=3600`` on discovery - responses. When the server restarts with different ``--permissions`` or - ``--read-only`` flags, browsers / MCP clients serve stale metadata that - advertises the wrong scopes, causing OAuth to silently fail. - - The wrapper overrides the header to ``no-store`` and adds an ``ETag`` - derived from the current scope set so intermediary caches that ignore - ``no-store`` still see a fingerprint change. - """ - - async def _no_cache_endpoint(request: Request) -> Response: - response = await endpoint(request) - response.headers["Cache-Control"] = "no-store, must-revalidate" - response.headers["ETag"] = etag - return response - - return _no_cache_endpoint - - # Custom FastMCP that adds secure middleware stack for OAuth 2.1 class SecureFastMCP(FastMCP): def http_app(self, **kwargs) -> "Starlette": @@ -75,12 +89,16 @@ class SecureFastMCP(FastMCP): app = super().http_app(**kwargs) # Add middleware in order (first added = outermost layer) + app.user_middleware.insert(0, well_known_cache_control_middleware) + # Session Management - extracts session info for MCP context - app.user_middleware.insert(0, session_middleware) + app.user_middleware.insert(1, session_middleware) # Rebuild middleware stack app.middleware_stack = app.build_middleware_stack() - logger.info("Added middleware stack: Session Management") + logger.info( + "Added middleware stack: WellKnownCacheControl, Session Management" + ) return app @@ -417,22 +435,6 @@ def configure_server_for_http(): "OAuth 2.1 enabled using FastMCP GoogleProvider with protocol-level auth" ) - # Mount well-known routes with cache-busting headers. - # The MCP SDK hardcodes Cache-Control: public, max-age=3600 - # on discovery responses which causes stale-scope bugs when - # the server is restarted with a different --permissions config. - try: - scope_etag = f'"{_compute_scope_fingerprint()}"' - well_known_routes = provider.get_well_known_routes() - for route in well_known_routes: - logger.info(f"Mounting OAuth well-known route: {route.path}") - wrapped = _wrap_well_known_endpoint(route.endpoint, scope_etag) - server.custom_route(route.path, methods=list(route.methods))( - wrapped - ) - except Exception as e: - logger.warning(f"Could not mount well-known routes: {e}") - # Always set auth provider for token validation in middleware set_auth_provider(provider) _auth_provider = provider diff --git a/tests/core/test_well_known_cache_control_middleware.py b/tests/core/test_well_known_cache_control_middleware.py new file mode 100644 index 0000000..66e9b98 --- /dev/null +++ b/tests/core/test_well_known_cache_control_middleware.py @@ -0,0 +1,88 @@ +import importlib + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.responses import Response +from starlette.routing import Route +from starlette.testclient import TestClient + + +def test_well_known_cache_control_middleware_rewrites_headers(): + from core.server import WellKnownCacheControlMiddleware, _compute_scope_fingerprint + + async def well_known_endpoint(request): + response = Response("ok") + response.headers["Cache-Control"] = "public, max-age=3600" + response.set_cookie("a", "1") + response.set_cookie("b", "2") + return response + + async def regular_endpoint(request): + response = Response("ok") + response.headers["Cache-Control"] = "public, max-age=3600" + return response + + app = Starlette( + routes=[ + Route("/.well-known/oauth-authorization-server", well_known_endpoint), + Route("/.well-known/oauth-authorization-server-extra", regular_endpoint), + Route("/health", regular_endpoint), + ], + middleware=[Middleware(WellKnownCacheControlMiddleware)], + ) + client = TestClient(app) + + well_known = client.get("/.well-known/oauth-authorization-server") + assert well_known.status_code == 200 + assert well_known.headers["cache-control"] == "no-store, must-revalidate" + assert well_known.headers["etag"] == f'"{_compute_scope_fingerprint()}"' + assert sorted(well_known.headers.get_list("set-cookie")) == sorted( + ["a=1; Path=/; SameSite=lax", "b=2; Path=/; SameSite=lax"] + ) + + regular = client.get("/health") + assert regular.status_code == 200 + assert regular.headers["cache-control"] == "public, max-age=3600" + assert "etag" not in regular.headers + + extra = client.get("/.well-known/oauth-authorization-server-extra") + assert extra.status_code == 200 + assert extra.headers["cache-control"] == "public, max-age=3600" + assert "etag" not in extra.headers + + +def test_configured_server_applies_no_cache_to_served_oauth_discovery_routes(monkeypatch): + monkeypatch.setenv("MCP_ENABLE_OAUTH21", "true") + monkeypatch.setenv("GOOGLE_OAUTH_CLIENT_ID", "dummy-client") + monkeypatch.setenv("GOOGLE_OAUTH_CLIENT_SECRET", "dummy-secret") + monkeypatch.setenv("WORKSPACE_MCP_BASE_URI", "http://localhost") + monkeypatch.setenv("WORKSPACE_MCP_PORT", "8000") + monkeypatch.delenv("WORKSPACE_EXTERNAL_URL", raising=False) + monkeypatch.setenv("EXTERNAL_OAUTH21_PROVIDER", "false") + + import core.server as core_server + from auth.oauth_config import reload_oauth_config + + reload_oauth_config() + core_server = importlib.reload(core_server) + core_server.set_transport_mode("streamable-http") + core_server.configure_server_for_http() + + app = core_server.server.http_app(transport="streamable-http", path="/mcp") + client = TestClient(app) + + authorization_server = client.get("/.well-known/oauth-authorization-server") + assert authorization_server.status_code == 200 + assert authorization_server.headers["cache-control"] == "no-store, must-revalidate" + assert authorization_server.headers["etag"].startswith('"') + assert authorization_server.headers["etag"].endswith('"') + + protected_resource = client.get("/.well-known/oauth-protected-resource/mcp") + assert protected_resource.status_code == 200 + assert protected_resource.headers["cache-control"] == "no-store, must-revalidate" + assert protected_resource.headers["etag"].startswith('"') + assert protected_resource.headers["etag"].endswith('"') + + # Ensure we did not create a shadow route at the wrong path. + wrong_path = client.get("/.well-known/oauth-protected-resource") + assert wrong_path.status_code == 404 diff --git a/uv.lock b/uv.lock index 859431b..2a68b9f 100644 --- a/uv.lock +++ b/uv.lock @@ -2044,7 +2044,7 @@ wheels = [ [[package]] name = "workspace-mcp" -version = "1.13.1" +version = "1.14.0" source = { editable = "." } dependencies = [ { name = "cryptography" },