source of truth for oauth2.1 enablement

This commit is contained in:
Taylor Wilsdon
2025-08-09 11:53:00 -04:00
parent 374dc9c3e7
commit 773645794a
6 changed files with 93 additions and 86 deletions

View File

@@ -282,8 +282,9 @@ This architecture enables any OAuth 2.1 compliant client to authenticate users t
</details> </details>
**For MCP Inspector**: No additional configuration needed with desktop OAuth client. **MCP Inspector**: No additional configuration needed with desktop OAuth client.
**Claude Code Inspector**: No additional configuration needed with desktop OAuth client.
### VS Code MCP Client Support ### VS Code MCP Client Support
@@ -304,17 +305,6 @@ The server includes native support for VS Code's MCP client:
} }
``` ```
**For VS Code**: No additional configuration needed with desktop OAuth client.
### Modular Architecture
The server uses a clean, modular architecture for maintainability and security with broad OAuth2.1 MCP Client support:
- **Centralized Configuration**: [`OAuthConfig`](auth/oauth_config.py) eliminates hardcoded values and provides environment-based configuration
- **Standardized Error Handling**: [`oauth_error_handling.py`](auth/oauth_error_handling.py) provides consistent error responses and input validation
- **Security-First Design**: Proper CORS handling, input sanitization, and comprehensive validation throughout
### Connect to Claude Desktop ### Connect to Claude Desktop
The server supports two transport modes: The server supports two transport modes:

View File

@@ -24,6 +24,7 @@ from pydantic import AnyHttpUrl
try: try:
from fastmcp.server.auth import RemoteAuthProvider from fastmcp.server.auth import RemoteAuthProvider
from fastmcp.server.auth.providers.jwt import JWTVerifier from fastmcp.server.auth.providers.jwt import JWTVerifier
REMOTEAUTHPROVIDER_AVAILABLE = True REMOTEAUTHPROVIDER_AVAILABLE = True
except ImportError: except ImportError:
REMOTEAUTHPROVIDER_AVAILABLE = False REMOTEAUTHPROVIDER_AVAILABLE = False
@@ -38,7 +39,7 @@ from auth.oauth_common_handlers import (
handle_oauth_protected_resource, handle_oauth_protected_resource,
handle_oauth_authorization_server, handle_oauth_authorization_server,
handle_oauth_client_config, handle_oauth_client_config,
handle_oauth_register handle_oauth_register,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -52,9 +53,6 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
- OAuth proxy endpoints for CORS workaround - OAuth proxy endpoints for CORS workaround
- Dynamic client registration support - Dynamic client registration support
- Session management with issuer tracking - Session management with issuer tracking
VS Code compatibility is now handled transparently by middleware,
eliminating the need for custom redirects and path handling.
""" """
def __init__(self): def __init__(self):
@@ -69,15 +67,19 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000))) self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
if not self.client_id: if not self.client_id:
logger.error("GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work") logger.error(
raise ValueError("GOOGLE_OAUTH_CLIENT_ID environment variable is required for OAuth 2.1 authentication") "GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work"
)
raise ValueError(
"GOOGLE_OAUTH_CLIENT_ID environment variable is required for OAuth 2.1 authentication"
)
# Configure JWT verifier for Google tokens # Configure JWT verifier for Google tokens
token_verifier = JWTVerifier( token_verifier = JWTVerifier(
jwks_uri="https://www.googleapis.com/oauth2/v3/certs", jwks_uri="https://www.googleapis.com/oauth2/v3/certs",
issuer="https://accounts.google.com", issuer="https://accounts.google.com",
audience=self.client_id, # Always use actual client_id audience=self.client_id, # Always use actual client_id
algorithm="RS256" algorithm="RS256",
) )
# Initialize RemoteAuthProvider with base URL (no /mcp/ suffix) # Initialize RemoteAuthProvider with base URL (no /mcp/ suffix)
@@ -85,55 +87,72 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
super().__init__( super().__init__(
token_verifier=token_verifier, token_verifier=token_verifier,
authorization_servers=[AnyHttpUrl(f"{self.base_url}:{self.port}")], authorization_servers=[AnyHttpUrl(f"{self.base_url}:{self.port}")],
resource_server_url=f"{self.base_url}:{self.port}" resource_server_url=f"{self.base_url}:{self.port}",
) )
logger.debug("GoogleRemoteAuthProvider initialized with VS Code compatibility") logger.debug("GoogleRemoteAuthProvider")
def get_routes(self) -> List[Route]: def get_routes(self) -> List[Route]:
""" """
Add OAuth routes at canonical locations. Add OAuth routes at canonical locations.
VS Code compatibility is now handled transparently by middleware,
so we only need to register routes at their canonical locations.
""" """
# Get the standard OAuth protected resource routes from RemoteAuthProvider # Get the standard OAuth protected resource routes from RemoteAuthProvider
parent_routes = super().get_routes() parent_routes = super().get_routes()
# Filter out the parent's oauth-protected-resource route since we're replacing it # Filter out the parent's oauth-protected-resource route since we're replacing it
routes = [r for r in parent_routes if r.path != "/.well-known/oauth-protected-resource"] routes = [
r
for r in parent_routes
if r.path != "/.well-known/oauth-protected-resource"
]
# Add our custom OAuth discovery endpoint that returns /mcp/ as the resource # Add our custom OAuth discovery endpoint that returns /mcp/ as the resource
routes.append(Route( routes.append(
Route(
"/.well-known/oauth-protected-resource", "/.well-known/oauth-protected-resource",
handle_oauth_protected_resource, handle_oauth_protected_resource,
methods=["GET", "OPTIONS"] methods=["GET", "OPTIONS"],
)) )
)
routes.append(Route( routes.append(
Route(
"/.well-known/oauth-authorization-server", "/.well-known/oauth-authorization-server",
handle_oauth_authorization_server, handle_oauth_authorization_server,
methods=["GET", "OPTIONS"] methods=["GET", "OPTIONS"],
)) )
)
routes.append(Route( routes.append(
Route(
"/.well-known/oauth-client", "/.well-known/oauth-client",
handle_oauth_client_config, handle_oauth_client_config,
methods=["GET", "OPTIONS"] methods=["GET", "OPTIONS"],
)) )
)
# Add OAuth flow endpoints # Add OAuth flow endpoints
routes.append(Route("/oauth2/authorize", handle_oauth_authorize, methods=["GET", "OPTIONS"])) routes.append(
routes.append(Route("/oauth2/token", handle_proxy_token_exchange, methods=["POST", "OPTIONS"])) Route(
routes.append(Route("/oauth2/register", handle_oauth_register, methods=["POST", "OPTIONS"])) "/oauth2/authorize", handle_oauth_authorize, methods=["GET", "OPTIONS"]
)
)
routes.append(
Route(
"/oauth2/token",
handle_proxy_token_exchange,
methods=["POST", "OPTIONS"],
)
)
routes.append(
Route(
"/oauth2/register", handle_oauth_register, methods=["POST", "OPTIONS"]
)
)
logger.info(f"Registered {len(routes)} OAuth routes") logger.info(f"Registered {len(routes)} OAuth routes")
return routes return routes
async def verify_token(self, token: str) -> Optional[object]: async def verify_token(self, token: str) -> Optional[object]:
""" """
Override verify_token to handle Google OAuth access tokens. Override verify_token to handle Google OAuth access tokens.
@@ -143,22 +162,30 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
""" """
# Check if this is a Google OAuth access token (starts with ya29.) # Check if this is a Google OAuth access token (starts with ya29.)
if token.startswith("ya29."): if token.startswith("ya29."):
logger.debug("Detected Google OAuth access token, using tokeninfo verification") logger.debug(
"Detected Google OAuth access token, using tokeninfo verification"
)
try: try:
# Verify the access token using Google's tokeninfo endpoint # Verify the access token using Google's tokeninfo endpoint
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
url = f"https://oauth2.googleapis.com/tokeninfo?access_token={token}" url = (
f"https://oauth2.googleapis.com/tokeninfo?access_token={token}"
)
async with session.get(url) as response: async with session.get(url) as response:
if response.status != 200: if response.status != 200:
logger.error(f"Token verification failed: {response.status}") logger.error(
f"Token verification failed: {response.status}"
)
return None return None
token_info = await response.json() token_info = await response.json()
# Verify the token is for our client # Verify the token is for our client
if token_info.get("aud") != self.client_id: if token_info.get("aud") != self.client_id:
logger.error(f"Token audience mismatch: expected {self.client_id}, got {token_info.get('aud')}") logger.error(
f"Token audience mismatch: expected {self.client_id}, got {token_info.get('aud')}"
)
return None return None
# Check if token is expired # Check if token is expired
@@ -173,7 +200,9 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
# Calculate expires_at timestamp # Calculate expires_at timestamp
expires_in = int(token_info.get("expires_in", 0)) expires_in = int(token_info.get("expires_in", 0))
expires_at = int(time.time()) + expires_in if expires_in > 0 else 0 expires_at = (
int(time.time()) + expires_in if expires_in > 0 else 0
)
access_token = SimpleNamespace( access_token = SimpleNamespace(
claims={ claims={
@@ -188,12 +217,15 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
client_id=self.client_id, # Add client_id at top level client_id=self.client_id, # Add client_id at top level
# Add other required fields # Add other required fields
sub=token_info.get("sub", ""), sub=token_info.get("sub", ""),
email=token_info.get("email", "") email=token_info.get("email", ""),
) )
user_email = token_info.get("email") user_email = token_info.get("email")
if user_email: if user_email:
from auth.oauth21_session_store import get_oauth21_session_store from auth.oauth21_session_store import (
get_oauth21_session_store,
)
store = get_oauth21_session_store() store = get_oauth21_session_store()
session_id = f"google_{token_info.get('sub', 'unknown')}" session_id = f"google_{token_info.get('sub', 'unknown')}"
@@ -201,10 +233,13 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
mcp_session_id = None mcp_session_id = None
try: try:
from fastmcp.server.dependencies import get_context from fastmcp.server.dependencies import get_context
ctx = get_context() ctx = get_context()
if ctx and hasattr(ctx, 'session_id'): if ctx and hasattr(ctx, "session_id"):
mcp_session_id = ctx.session_id mcp_session_id = ctx.session_id
logger.debug(f"Binding MCP session {mcp_session_id} to user {user_email}") logger.debug(
f"Binding MCP session {mcp_session_id} to user {user_email}"
)
except Exception: except Exception:
pass pass
@@ -215,7 +250,7 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
scopes=access_token.scopes, scopes=access_token.scopes,
session_id=session_id, session_id=session_id,
mcp_session_id=mcp_session_id, mcp_session_id=mcp_session_id,
issuer="https://accounts.google.com" issuer="https://accounts.google.com",
) )
logger.info(f"Verified OAuth token: {user_email}") logger.info(f"Verified OAuth token: {user_email}")
@@ -236,6 +271,7 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
user_email = access_token.claims.get("email") user_email = access_token.claims.get("email")
if user_email: if user_email:
from auth.oauth21_session_store import get_oauth21_session_store from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store() store = get_oauth21_session_store()
session_id = f"google_{access_token.claims.get('sub', 'unknown')}" session_id = f"google_{access_token.claims.get('sub', 'unknown')}"
@@ -245,9 +281,11 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
access_token=token, access_token=token,
scopes=access_token.scopes or [], scopes=access_token.scopes or [],
session_id=session_id, session_id=session_id,
issuer="https://accounts.google.com" issuer="https://accounts.google.com",
) )
logger.debug(f"Successfully verified JWT token for user: {user_email}") logger.debug(
f"Successfully verified JWT token for user: {user_email}"
)
return access_token return access_token

View File

@@ -160,21 +160,6 @@ def set_auth_layer(auth_layer):
logger.info("set_auth_layer called - OAuth is now handled by FastMCP") logger.info("set_auth_layer called - OAuth is now handled by FastMCP")
def is_oauth21_enabled() -> bool:
"""
Check if the OAuth 2.1 authentication layer is active.
Uses centralized configuration from oauth_config.
"""
from auth.oauth_config import is_oauth21_enabled as config_oauth21_enabled
return config_oauth21_enabled()
def enable_oauth21():
"""
Enable the OAuth 2.1 authentication layer.
Note: This is now controlled by MCP_ENABLE_OAUTH21 env var via oauth_config.
"""
logger.debug("OAuth 2.1 authentication enable request - controlled by MCP_ENABLE_OAUTH21 env var")
async def get_legacy_auth_service( async def get_legacy_auth_service(
@@ -274,7 +259,6 @@ async def get_authenticated_google_service_oauth21_v2(
Returns: Returns:
Tuple of (service instance, actual user email) Tuple of (service instance, actual user email)
""" """
from auth.oauth_types import OAuth21ServiceRequest
# Delegate to the original function for now # Delegate to the original function for now
# This provides a migration path while maintaining backward compatibility # This provides a migration path while maintaining backward compatibility

View File

@@ -486,7 +486,7 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
logger.debug(f"[{tool_name}] Could not get FastMCP context: {e}") logger.debug(f"[{tool_name}] Could not get FastMCP context: {e}")
# Use the same logic as single service decorator # Use the same logic as single service decorator
from auth.oauth21_integration import is_oauth21_enabled from auth.oauth_config import is_oauth21_enabled
if is_oauth21_enabled(): if is_oauth21_enabled():
logger.debug(f"[{tool_name}] Attempting OAuth 2.1 authentication flow for {service_type}.") logger.debug(f"[{tool_name}] Attempting OAuth 2.1 authentication flow for {service_type}.")

View File

@@ -101,8 +101,7 @@ def configure_server_for_http():
_auth_provider = GoogleRemoteAuthProvider() _auth_provider = GoogleRemoteAuthProvider()
server.auth = _auth_provider server.auth = _auth_provider
set_auth_provider(_auth_provider) set_auth_provider(_auth_provider)
from auth.oauth21_integration import enable_oauth21 logger.debug("OAuth 2.1 authentication enabled")
enable_oauth21() # This is now just a logging call
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize GoogleRemoteAuthProvider: {e}", exc_info=True) logger.error(f"Failed to initialize GoogleRemoteAuthProvider: {e}", exc_info=True)
else: else:

View File

@@ -6,14 +6,10 @@ from importlib import metadata
from dotenv import load_dotenv from dotenv import load_dotenv
# Load environment variables from .env file BEFORE any other imports # Load environment variables from .env file BEFORE any other imports
# This ensures OAuth config gets the right environment variables
dotenv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.env') dotenv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.env')
load_dotenv(dotenv_path=dotenv_path) load_dotenv(dotenv_path=dotenv_path)
# Now import modules that depend on environment variables
from core.server import server, set_transport_mode, configure_server_for_http from core.server import server, set_transport_mode, configure_server_for_http
# Reload OAuth config after loading .env to pick up credentials
from auth.oauth_config import reload_oauth_config from auth.oauth_config import reload_oauth_config
reload_oauth_config() reload_oauth_config()