Fix TypeError: CORSMiddleware.__call__() missing 2 required positional arguments
The _wrap_well_known_endpoint function assumed all route endpoints are regular request handlers (async def handler(request) -> Response). However, the MCP SDK's cors_middleware wraps handlers with CORSMiddleware, which is an ASGI app expecting (scope, receive, send). When the wrapper called `await endpoint(request)` on a CORSMiddleware instance, it passed only 1 argument instead of the required 3 ASGI args. The fix detects whether the endpoint is a regular handler function or an ASGI app (using the same inspect check as Starlette's Route constructor), and uses the appropriate calling convention: - Regular handlers: called as `await endpoint(request)` (existing behavior) - ASGI apps: invoked via the ASGI interface `await endpoint(scope, receive, send)` with response capture to apply cache-busting headers https://claude.ai/code/session_011S5zFTWRfKBJBUEanrhvQg
This commit is contained in:
@@ -57,15 +57,67 @@ def _wrap_well_known_endpoint(endpoint, etag: str):
|
|||||||
The wrapper overrides the header to ``no-store`` and adds an ``ETag``
|
The wrapper overrides the header to ``no-store`` and adds an ``ETag``
|
||||||
derived from the current scope set so intermediary caches that ignore
|
derived from the current scope set so intermediary caches that ignore
|
||||||
``no-store`` still see a fingerprint change.
|
``no-store`` still see a fingerprint change.
|
||||||
|
|
||||||
|
Handles both regular request handlers (``async def handler(request)``)
|
||||||
|
and ASGI app endpoints (e.g. ``CORSMiddleware``) which expect
|
||||||
|
``(scope, receive, send)`` instead.
|
||||||
"""
|
"""
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
|
||||||
async def _no_cache_endpoint(request: Request) -> Response:
|
# Determine if endpoint is a regular handler function or an ASGI app.
|
||||||
response = await endpoint(request)
|
# Starlette's Route uses the same check (see starlette/routing.py).
|
||||||
response.headers["Cache-Control"] = "no-store, must-revalidate"
|
endpoint_handler = endpoint
|
||||||
response.headers["ETag"] = etag
|
while isinstance(endpoint_handler, functools.partial):
|
||||||
return response
|
endpoint_handler = endpoint_handler.func
|
||||||
|
is_regular_handler = inspect.isfunction(endpoint_handler) or inspect.ismethod(
|
||||||
|
endpoint_handler
|
||||||
|
)
|
||||||
|
|
||||||
return _no_cache_endpoint
|
if is_regular_handler:
|
||||||
|
|
||||||
|
async def _no_cache_endpoint(request: Request) -> Response:
|
||||||
|
response = await endpoint(request)
|
||||||
|
response.headers["Cache-Control"] = "no-store, must-revalidate"
|
||||||
|
response.headers["ETag"] = etag
|
||||||
|
return response
|
||||||
|
|
||||||
|
return _no_cache_endpoint
|
||||||
|
else:
|
||||||
|
# ASGI app (e.g. CORSMiddleware wrapping a handler).
|
||||||
|
# Invoke via the ASGI interface and capture the response so we can
|
||||||
|
# override cache headers before returning.
|
||||||
|
async def _no_cache_asgi_endpoint(request: Request) -> Response:
|
||||||
|
status_code = 200
|
||||||
|
raw_headers: list[tuple[bytes, bytes]] = []
|
||||||
|
body_parts: list[bytes] = []
|
||||||
|
|
||||||
|
async def send(message):
|
||||||
|
nonlocal status_code, raw_headers
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
status_code = message["status"]
|
||||||
|
raw_headers = list(message.get("headers", []))
|
||||||
|
elif message["type"] == "http.response.body":
|
||||||
|
body_parts.append(message.get("body", b""))
|
||||||
|
|
||||||
|
await endpoint(request.scope, request._receive, send)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
(k.decode() if isinstance(k, bytes) else k): (
|
||||||
|
v.decode() if isinstance(v, bytes) else v
|
||||||
|
)
|
||||||
|
for k, v in raw_headers
|
||||||
|
}
|
||||||
|
response = Response(
|
||||||
|
content=b"".join(body_parts),
|
||||||
|
status_code=status_code,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
response.headers["Cache-Control"] = "no-store, must-revalidate"
|
||||||
|
response.headers["ETag"] = etag
|
||||||
|
return response
|
||||||
|
|
||||||
|
return _no_cache_asgi_endpoint
|
||||||
|
|
||||||
|
|
||||||
# Custom FastMCP that adds secure middleware stack for OAuth 2.1
|
# Custom FastMCP that adds secure middleware stack for OAuth 2.1
|
||||||
|
|||||||
Reference in New Issue
Block a user