diff --git a/core/server.py b/core/server.py index 97d1be6..7b68848 100644 --- a/core/server.py +++ b/core/server.py @@ -1,10 +1,12 @@ +import hashlib import logging import os -from typing import List, Optional +from typing import Callable, List, Optional from importlib import metadata from fastapi.responses import HTMLResponse, JSONResponse, FileResponse from starlette.applications import Starlette +from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.middleware import Middleware @@ -38,6 +40,38 @@ _legacy_callback_registered = False session_middleware = Middleware(MCPSessionMiddleware) +def _compute_scope_fingerprint() -> str: + """Compute a short hash of the current scope configuration for cache-busting.""" + scopes_str = ",".join(sorted(get_current_scopes())) + return hashlib.sha256(scopes_str.encode()).hexdigest()[:12] + + +class OAuthMetadataCacheBustMiddleware(BaseHTTPMiddleware): + """Override the upstream 1-hour Cache-Control on OAuth discovery endpoints. + + The MCP SDK sets ``Cache-Control: public, max-age=3600`` on the + ``.well-known`` metadata responses. When the server is restarted with a + different ``--permissions`` or ``--read-only`` configuration, browsers / + MCP clients can serve stale discovery docs that advertise the wrong + scopes, causing the OAuth flow to silently fail. + + This middleware replaces the cache header with ``no-store`` and adds an + ``ETag`` derived from the current scope set so that intermediary caches + that *do* store the response will still invalidate on config change. + """ + + def __init__(self, app: Starlette, scope_fingerprint: str) -> None: + super().__init__(app) + self._etag = f'"{scope_fingerprint}"' + + async def dispatch(self, request: Request, call_next: Callable): + response = await call_next(request) + if request.url.path.startswith("/.well-known/"): + response.headers["Cache-Control"] = "no-store, must-revalidate" + response.headers["ETag"] = self._etag + return response + + # Custom FastMCP that adds secure middleware stack for OAuth 2.1 class SecureFastMCP(FastMCP): def http_app(self, **kwargs) -> "Starlette": @@ -48,6 +82,13 @@ class SecureFastMCP(FastMCP): # Session Management - extracts session info for MCP context app.user_middleware.insert(0, session_middleware) + # Prevent browser caching of OAuth discovery endpoints across config changes + fingerprint = _compute_scope_fingerprint() + app.user_middleware.insert( + 0, + Middleware(OAuthMetadataCacheBustMiddleware, scope_fingerprint=fingerprint), + ) + # Rebuild middleware stack app.middleware_stack = app.build_middleware_stack() logger.info("Added middleware stack: Session Management") diff --git a/main.py b/main.py index 8dc28be..d5c288d 100644 --- a/main.py +++ b/main.py @@ -108,6 +108,13 @@ def resolve_permissions_mode_selection( return tier_services, set(tier_tools) +def narrow_permissions_to_services( + permissions: dict[str, str], services: list[str] +) -> dict[str, str]: + """Restrict permission entries to the provided service list order.""" + return {service: permissions[service] for service in services if service in permissions} + + def main(): """ Main entry point for the Google Workspace MCP server. @@ -199,6 +206,13 @@ def main(): file=sys.stderr, ) sys.exit(1) + if args.permissions and args.tools is not None: + print( + "Error: --permissions and --tools cannot be combined. " + "Select services via --permissions (optionally with --tool-tier).", + file=sys.stderr, + ) + sys.exit(1) # Set port and base URI once for reuse throughout the function port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000))) @@ -315,7 +329,6 @@ def main(): except ValueError as e: print(f"Error: {e}", file=sys.stderr) sys.exit(1) - set_permissions(perms) # Permissions implicitly defines which services to load tools_to_import = list(perms.keys()) set_enabled_tool_names(None) @@ -327,12 +340,14 @@ def main(): tools_to_import, args.tool_tier ) set_enabled_tool_names(tier_tool_filter) + perms = narrow_permissions_to_services(perms, tools_to_import) except Exception as e: print( f"Error loading tools for tier '{args.tool_tier}': {e}", file=sys.stderr, ) sys.exit(1) + set_permissions(perms) elif args.tool_tier is not None: # Use tier-based tool selection, optionally filtered by services try: diff --git a/tests/test_scopes.py b/tests/test_scopes.py index 43448b1..502df3d 100644 --- a/tests/test_scopes.py +++ b/tests/test_scopes.py @@ -12,6 +12,7 @@ import os sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from auth.scopes import ( + BASE_SCOPES, CALENDAR_READONLY_SCOPE, CALENDAR_SCOPE, CONTACTS_READONLY_SCOPE, @@ -31,6 +32,8 @@ from auth.scopes import ( has_required_scopes, set_read_only, ) +from auth.permissions import get_scopes_for_permission, set_permissions +import auth.permissions as permissions_module class TestDocsScopes: @@ -195,3 +198,34 @@ class TestHasRequiredScopes: available = [GMAIL_MODIFY_SCOPE] required = [GMAIL_READONLY_SCOPE, DRIVE_READONLY_SCOPE] assert not has_required_scopes(available, required) + + +class TestGranularPermissionsScopes: + """Tests for granular permissions scope generation path.""" + + def setup_method(self): + set_read_only(False) + permissions_module._PERMISSIONS = None + + def teardown_method(self): + set_read_only(False) + permissions_module._PERMISSIONS = None + + def test_permissions_mode_returns_base_plus_permission_scopes(self): + set_permissions({"gmail": "send", "drive": "readonly"}) + scopes = get_scopes_for_tools(["calendar"]) # ignored in permissions mode + + expected = set(BASE_SCOPES) + expected.update(get_scopes_for_permission("gmail", "send")) + expected.update(get_scopes_for_permission("drive", "readonly")) + assert set(scopes) == expected + + def test_permissions_mode_overrides_read_only_and_full_maps(self): + set_read_only(True) + without_permissions = get_scopes_for_tools(["drive"]) + assert DRIVE_READONLY_SCOPE in without_permissions + + set_permissions({"gmail": "readonly"}) + with_permissions = get_scopes_for_tools(["drive"]) + assert GMAIL_READONLY_SCOPE in with_permissions + assert DRIVE_READONLY_SCOPE not in with_permissions