Merge pull request #523 from jack-distl/claude/fix-cors-middleware-error-sPbo6
Fix TypeError: CORSMiddleware.__call__() missing 2 required positiona…
This commit is contained in:
@@ -6,8 +6,9 @@ from importlib import metadata
|
||||
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
|
||||
from starlette.applications import Starlette
|
||||
from starlette.datastructures import MutableHeaders
|
||||
from starlette.types import Scope, Receive, Send
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.middleware import Middleware
|
||||
|
||||
from fastmcp import FastMCP
|
||||
@@ -40,34 +41,47 @@ _legacy_callback_registered = False
|
||||
session_middleware = Middleware(MCPSessionMiddleware)
|
||||
|
||||
|
||||
class WellKnownCacheControlMiddleware:
|
||||
"""Force no-cache headers for OAuth well-known discovery endpoints."""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
path = scope.get("path", "")
|
||||
is_oauth_well_known = (
|
||||
path == "/.well-known/oauth-authorization-server"
|
||||
or path.startswith("/.well-known/oauth-authorization-server/")
|
||||
or path == "/.well-known/oauth-protected-resource"
|
||||
or path.startswith("/.well-known/oauth-protected-resource/")
|
||||
)
|
||||
if not is_oauth_well_known:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
async def send_with_no_cache_headers(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = MutableHeaders(raw=message.setdefault("headers", []))
|
||||
headers["Cache-Control"] = "no-store, must-revalidate"
|
||||
headers["ETag"] = f'"{_compute_scope_fingerprint()}"'
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_with_no_cache_headers)
|
||||
|
||||
|
||||
well_known_cache_control_middleware = Middleware(WellKnownCacheControlMiddleware)
|
||||
|
||||
|
||||
def _compute_scope_fingerprint() -> str:
|
||||
"""Compute a short hash of the current scope configuration for cache-busting."""
|
||||
scopes_str = ",".join(sorted(get_current_scopes()))
|
||||
return hashlib.sha256(scopes_str.encode()).hexdigest()[:12]
|
||||
|
||||
|
||||
def _wrap_well_known_endpoint(endpoint, etag: str):
|
||||
"""Wrap a well-known metadata endpoint to prevent browser caching.
|
||||
|
||||
The MCP SDK hardcodes ``Cache-Control: public, max-age=3600`` on discovery
|
||||
responses. When the server restarts with different ``--permissions`` or
|
||||
``--read-only`` flags, browsers / MCP clients serve stale metadata that
|
||||
advertises the wrong scopes, causing OAuth to silently fail.
|
||||
|
||||
The wrapper overrides the header to ``no-store`` and adds an ``ETag``
|
||||
derived from the current scope set so intermediary caches that ignore
|
||||
``no-store`` still see a fingerprint change.
|
||||
"""
|
||||
|
||||
async def _no_cache_endpoint(request: Request) -> Response:
|
||||
response = await endpoint(request)
|
||||
response.headers["Cache-Control"] = "no-store, must-revalidate"
|
||||
response.headers["ETag"] = etag
|
||||
return response
|
||||
|
||||
return _no_cache_endpoint
|
||||
|
||||
|
||||
# Custom FastMCP that adds secure middleware stack for OAuth 2.1
|
||||
class SecureFastMCP(FastMCP):
|
||||
def http_app(self, **kwargs) -> "Starlette":
|
||||
@@ -75,12 +89,16 @@ class SecureFastMCP(FastMCP):
|
||||
app = super().http_app(**kwargs)
|
||||
|
||||
# Add middleware in order (first added = outermost layer)
|
||||
app.user_middleware.insert(0, well_known_cache_control_middleware)
|
||||
|
||||
# Session Management - extracts session info for MCP context
|
||||
app.user_middleware.insert(0, session_middleware)
|
||||
app.user_middleware.insert(1, session_middleware)
|
||||
|
||||
# Rebuild middleware stack
|
||||
app.middleware_stack = app.build_middleware_stack()
|
||||
logger.info("Added middleware stack: Session Management")
|
||||
logger.info(
|
||||
"Added middleware stack: WellKnownCacheControl, Session Management"
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
@@ -417,22 +435,6 @@ def configure_server_for_http():
|
||||
"OAuth 2.1 enabled using FastMCP GoogleProvider with protocol-level auth"
|
||||
)
|
||||
|
||||
# Mount well-known routes with cache-busting headers.
|
||||
# The MCP SDK hardcodes Cache-Control: public, max-age=3600
|
||||
# on discovery responses which causes stale-scope bugs when
|
||||
# the server is restarted with a different --permissions config.
|
||||
try:
|
||||
scope_etag = f'"{_compute_scope_fingerprint()}"'
|
||||
well_known_routes = provider.get_well_known_routes()
|
||||
for route in well_known_routes:
|
||||
logger.info(f"Mounting OAuth well-known route: {route.path}")
|
||||
wrapped = _wrap_well_known_endpoint(route.endpoint, scope_etag)
|
||||
server.custom_route(route.path, methods=list(route.methods))(
|
||||
wrapped
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not mount well-known routes: {e}")
|
||||
|
||||
# Always set auth provider for token validation in middleware
|
||||
set_auth_provider(provider)
|
||||
_auth_provider = provider
|
||||
|
||||
88
tests/core/test_well_known_cache_control_middleware.py
Normal file
88
tests/core/test_well_known_cache_control_middleware.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import importlib
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Route
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
def test_well_known_cache_control_middleware_rewrites_headers():
|
||||
from core.server import WellKnownCacheControlMiddleware, _compute_scope_fingerprint
|
||||
|
||||
async def well_known_endpoint(request):
|
||||
response = Response("ok")
|
||||
response.headers["Cache-Control"] = "public, max-age=3600"
|
||||
response.set_cookie("a", "1")
|
||||
response.set_cookie("b", "2")
|
||||
return response
|
||||
|
||||
async def regular_endpoint(request):
|
||||
response = Response("ok")
|
||||
response.headers["Cache-Control"] = "public, max-age=3600"
|
||||
return response
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Route("/.well-known/oauth-authorization-server", well_known_endpoint),
|
||||
Route("/.well-known/oauth-authorization-server-extra", regular_endpoint),
|
||||
Route("/health", regular_endpoint),
|
||||
],
|
||||
middleware=[Middleware(WellKnownCacheControlMiddleware)],
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
well_known = client.get("/.well-known/oauth-authorization-server")
|
||||
assert well_known.status_code == 200
|
||||
assert well_known.headers["cache-control"] == "no-store, must-revalidate"
|
||||
assert well_known.headers["etag"] == f'"{_compute_scope_fingerprint()}"'
|
||||
assert sorted(well_known.headers.get_list("set-cookie")) == sorted(
|
||||
["a=1; Path=/; SameSite=lax", "b=2; Path=/; SameSite=lax"]
|
||||
)
|
||||
|
||||
regular = client.get("/health")
|
||||
assert regular.status_code == 200
|
||||
assert regular.headers["cache-control"] == "public, max-age=3600"
|
||||
assert "etag" not in regular.headers
|
||||
|
||||
extra = client.get("/.well-known/oauth-authorization-server-extra")
|
||||
assert extra.status_code == 200
|
||||
assert extra.headers["cache-control"] == "public, max-age=3600"
|
||||
assert "etag" not in extra.headers
|
||||
|
||||
|
||||
def test_configured_server_applies_no_cache_to_served_oauth_discovery_routes(monkeypatch):
|
||||
monkeypatch.setenv("MCP_ENABLE_OAUTH21", "true")
|
||||
monkeypatch.setenv("GOOGLE_OAUTH_CLIENT_ID", "dummy-client")
|
||||
monkeypatch.setenv("GOOGLE_OAUTH_CLIENT_SECRET", "dummy-secret")
|
||||
monkeypatch.setenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
|
||||
monkeypatch.setenv("WORKSPACE_MCP_PORT", "8000")
|
||||
monkeypatch.delenv("WORKSPACE_EXTERNAL_URL", raising=False)
|
||||
monkeypatch.setenv("EXTERNAL_OAUTH21_PROVIDER", "false")
|
||||
|
||||
import core.server as core_server
|
||||
from auth.oauth_config import reload_oauth_config
|
||||
|
||||
reload_oauth_config()
|
||||
core_server = importlib.reload(core_server)
|
||||
core_server.set_transport_mode("streamable-http")
|
||||
core_server.configure_server_for_http()
|
||||
|
||||
app = core_server.server.http_app(transport="streamable-http", path="/mcp")
|
||||
client = TestClient(app)
|
||||
|
||||
authorization_server = client.get("/.well-known/oauth-authorization-server")
|
||||
assert authorization_server.status_code == 200
|
||||
assert authorization_server.headers["cache-control"] == "no-store, must-revalidate"
|
||||
assert authorization_server.headers["etag"].startswith('"')
|
||||
assert authorization_server.headers["etag"].endswith('"')
|
||||
|
||||
protected_resource = client.get("/.well-known/oauth-protected-resource/mcp")
|
||||
assert protected_resource.status_code == 200
|
||||
assert protected_resource.headers["cache-control"] == "no-store, must-revalidate"
|
||||
assert protected_resource.headers["etag"].startswith('"')
|
||||
assert protected_resource.headers["etag"].endswith('"')
|
||||
|
||||
# Ensure we did not create a shadow route at the wrong path.
|
||||
wrong_path = client.get("/.well-known/oauth-protected-resource")
|
||||
assert wrong_path.status_code == 404
|
||||
Reference in New Issue
Block a user