cachebusting for oauth endpoints, more tests, startup check for perms

This commit is contained in:
Taylor Wilsdon
2026-02-28 11:40:29 -04:00
parent f2986dcf2f
commit edf9e94829
3 changed files with 92 additions and 2 deletions

View File

@@ -1,10 +1,12 @@
import hashlib
import logging import logging
import os import os
from typing import List, Optional from typing import Callable, List, Optional
from importlib import metadata from importlib import metadata
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.middleware import Middleware from starlette.middleware import Middleware
@@ -38,6 +40,38 @@ _legacy_callback_registered = False
session_middleware = Middleware(MCPSessionMiddleware) session_middleware = Middleware(MCPSessionMiddleware)
def _compute_scope_fingerprint() -> str:
"""Compute a short hash of the current scope configuration for cache-busting."""
scopes_str = ",".join(sorted(get_current_scopes()))
return hashlib.sha256(scopes_str.encode()).hexdigest()[:12]
class OAuthMetadataCacheBustMiddleware(BaseHTTPMiddleware):
"""Override the upstream 1-hour Cache-Control on OAuth discovery endpoints.
The MCP SDK sets ``Cache-Control: public, max-age=3600`` on the
``.well-known`` metadata responses. When the server is restarted with a
different ``--permissions`` or ``--read-only`` configuration, browsers /
MCP clients can serve stale discovery docs that advertise the wrong
scopes, causing the OAuth flow to silently fail.
This middleware replaces the cache header with ``no-store`` and adds an
``ETag`` derived from the current scope set so that intermediary caches
that *do* store the response will still invalidate on config change.
"""
def __init__(self, app: Starlette, scope_fingerprint: str) -> None:
super().__init__(app)
self._etag = f'"{scope_fingerprint}"'
async def dispatch(self, request: Request, call_next: Callable):
response = await call_next(request)
if request.url.path.startswith("/.well-known/"):
response.headers["Cache-Control"] = "no-store, must-revalidate"
response.headers["ETag"] = self._etag
return response
# Custom FastMCP that adds secure middleware stack for OAuth 2.1 # Custom FastMCP that adds secure middleware stack for OAuth 2.1
class SecureFastMCP(FastMCP): class SecureFastMCP(FastMCP):
def http_app(self, **kwargs) -> "Starlette": def http_app(self, **kwargs) -> "Starlette":
@@ -48,6 +82,13 @@ class SecureFastMCP(FastMCP):
# Session Management - extracts session info for MCP context # Session Management - extracts session info for MCP context
app.user_middleware.insert(0, session_middleware) app.user_middleware.insert(0, session_middleware)
# Prevent browser caching of OAuth discovery endpoints across config changes
fingerprint = _compute_scope_fingerprint()
app.user_middleware.insert(
0,
Middleware(OAuthMetadataCacheBustMiddleware, scope_fingerprint=fingerprint),
)
# Rebuild middleware stack # Rebuild middleware stack
app.middleware_stack = app.build_middleware_stack() app.middleware_stack = app.build_middleware_stack()
logger.info("Added middleware stack: Session Management") logger.info("Added middleware stack: Session Management")

17
main.py
View File

@@ -108,6 +108,13 @@ def resolve_permissions_mode_selection(
return tier_services, set(tier_tools) return tier_services, set(tier_tools)
def narrow_permissions_to_services(
permissions: dict[str, str], services: list[str]
) -> dict[str, str]:
"""Restrict permission entries to the provided service list order."""
return {service: permissions[service] for service in services if service in permissions}
def main(): def main():
""" """
Main entry point for the Google Workspace MCP server. Main entry point for the Google Workspace MCP server.
@@ -199,6 +206,13 @@ def main():
file=sys.stderr, file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
if args.permissions and args.tools is not None:
print(
"Error: --permissions and --tools cannot be combined. "
"Select services via --permissions (optionally with --tool-tier).",
file=sys.stderr,
)
sys.exit(1)
# Set port and base URI once for reuse throughout the function # Set port and base URI once for reuse throughout the function
port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000))) port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
@@ -315,7 +329,6 @@ def main():
except ValueError as e: except ValueError as e:
print(f"Error: {e}", file=sys.stderr) print(f"Error: {e}", file=sys.stderr)
sys.exit(1) sys.exit(1)
set_permissions(perms)
# Permissions implicitly defines which services to load # Permissions implicitly defines which services to load
tools_to_import = list(perms.keys()) tools_to_import = list(perms.keys())
set_enabled_tool_names(None) set_enabled_tool_names(None)
@@ -327,12 +340,14 @@ def main():
tools_to_import, args.tool_tier tools_to_import, args.tool_tier
) )
set_enabled_tool_names(tier_tool_filter) set_enabled_tool_names(tier_tool_filter)
perms = narrow_permissions_to_services(perms, tools_to_import)
except Exception as e: except Exception as e:
print( print(
f"Error loading tools for tier '{args.tool_tier}': {e}", f"Error loading tools for tier '{args.tool_tier}': {e}",
file=sys.stderr, file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
set_permissions(perms)
elif args.tool_tier is not None: elif args.tool_tier is not None:
# Use tier-based tool selection, optionally filtered by services # Use tier-based tool selection, optionally filtered by services
try: try:

View File

@@ -12,6 +12,7 @@ import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from auth.scopes import ( from auth.scopes import (
BASE_SCOPES,
CALENDAR_READONLY_SCOPE, CALENDAR_READONLY_SCOPE,
CALENDAR_SCOPE, CALENDAR_SCOPE,
CONTACTS_READONLY_SCOPE, CONTACTS_READONLY_SCOPE,
@@ -31,6 +32,8 @@ from auth.scopes import (
has_required_scopes, has_required_scopes,
set_read_only, set_read_only,
) )
from auth.permissions import get_scopes_for_permission, set_permissions
import auth.permissions as permissions_module
class TestDocsScopes: class TestDocsScopes:
@@ -195,3 +198,34 @@ class TestHasRequiredScopes:
available = [GMAIL_MODIFY_SCOPE] available = [GMAIL_MODIFY_SCOPE]
required = [GMAIL_READONLY_SCOPE, DRIVE_READONLY_SCOPE] required = [GMAIL_READONLY_SCOPE, DRIVE_READONLY_SCOPE]
assert not has_required_scopes(available, required) assert not has_required_scopes(available, required)
class TestGranularPermissionsScopes:
"""Tests for granular permissions scope generation path."""
def setup_method(self):
set_read_only(False)
permissions_module._PERMISSIONS = None
def teardown_method(self):
set_read_only(False)
permissions_module._PERMISSIONS = None
def test_permissions_mode_returns_base_plus_permission_scopes(self):
set_permissions({"gmail": "send", "drive": "readonly"})
scopes = get_scopes_for_tools(["calendar"]) # ignored in permissions mode
expected = set(BASE_SCOPES)
expected.update(get_scopes_for_permission("gmail", "send"))
expected.update(get_scopes_for_permission("drive", "readonly"))
assert set(scopes) == expected
def test_permissions_mode_overrides_read_only_and_full_maps(self):
set_read_only(True)
without_permissions = get_scopes_for_tools(["drive"])
assert DRIVE_READONLY_SCOPE in without_permissions
set_permissions({"gmail": "readonly"})
with_permissions = get_scopes_for_tools(["drive"])
assert GMAIL_READONLY_SCOPE in with_permissions
assert DRIVE_READONLY_SCOPE not in with_permissions