source of truth for oauth2.1 enablement

This commit is contained in:
Taylor Wilsdon
2025-08-09 11:53:00 -04:00
parent 374dc9c3e7
commit 773645794a
6 changed files with 93 additions and 86 deletions

View File

@@ -24,6 +24,7 @@ from pydantic import AnyHttpUrl
try:
from fastmcp.server.auth import RemoteAuthProvider
from fastmcp.server.auth.providers.jwt import JWTVerifier
REMOTEAUTHPROVIDER_AVAILABLE = True
except ImportError:
REMOTEAUTHPROVIDER_AVAILABLE = False
@@ -38,7 +39,7 @@ from auth.oauth_common_handlers import (
handle_oauth_protected_resource,
handle_oauth_authorization_server,
handle_oauth_client_config,
handle_oauth_register
handle_oauth_register,
)
logger = logging.getLogger(__name__)
@@ -52,9 +53,6 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
- OAuth proxy endpoints for CORS workaround
- Dynamic client registration support
- Session management with issuer tracking
VS Code compatibility is now handled transparently by middleware,
eliminating the need for custom redirects and path handling.
"""
def __init__(self):
@@ -69,15 +67,19 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
if not self.client_id:
logger.error("GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work")
raise ValueError("GOOGLE_OAUTH_CLIENT_ID environment variable is required for OAuth 2.1 authentication")
logger.error(
"GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work"
)
raise ValueError(
"GOOGLE_OAUTH_CLIENT_ID environment variable is required for OAuth 2.1 authentication"
)
# Configure JWT verifier for Google tokens
token_verifier = JWTVerifier(
jwks_uri="https://www.googleapis.com/oauth2/v3/certs",
issuer="https://accounts.google.com",
audience=self.client_id, # Always use actual client_id
algorithm="RS256"
algorithm="RS256",
)
# Initialize RemoteAuthProvider with base URL (no /mcp/ suffix)
@@ -85,55 +87,72 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
super().__init__(
token_verifier=token_verifier,
authorization_servers=[AnyHttpUrl(f"{self.base_url}:{self.port}")],
resource_server_url=f"{self.base_url}:{self.port}"
resource_server_url=f"{self.base_url}:{self.port}",
)
logger.debug("GoogleRemoteAuthProvider initialized with VS Code compatibility")
logger.debug("GoogleRemoteAuthProvider")
def get_routes(self) -> List[Route]:
"""
Add OAuth routes at canonical locations.
VS Code compatibility is now handled transparently by middleware,
so we only need to register routes at their canonical locations.
"""
# Get the standard OAuth protected resource routes from RemoteAuthProvider
parent_routes = super().get_routes()
# Filter out the parent's oauth-protected-resource route since we're replacing it
routes = [r for r in parent_routes if r.path != "/.well-known/oauth-protected-resource"]
routes = [
r
for r in parent_routes
if r.path != "/.well-known/oauth-protected-resource"
]
# Add our custom OAuth discovery endpoint that returns /mcp/ as the resource
routes.append(Route(
"/.well-known/oauth-protected-resource",
handle_oauth_protected_resource,
methods=["GET", "OPTIONS"]
))
routes.append(
Route(
"/.well-known/oauth-protected-resource",
handle_oauth_protected_resource,
methods=["GET", "OPTIONS"],
)
)
routes.append(Route(
"/.well-known/oauth-authorization-server",
handle_oauth_authorization_server,
methods=["GET", "OPTIONS"]
))
routes.append(
Route(
"/.well-known/oauth-authorization-server",
handle_oauth_authorization_server,
methods=["GET", "OPTIONS"],
)
)
routes.append(Route(
"/.well-known/oauth-client",
handle_oauth_client_config,
methods=["GET", "OPTIONS"]
))
routes.append(
Route(
"/.well-known/oauth-client",
handle_oauth_client_config,
methods=["GET", "OPTIONS"],
)
)
# Add OAuth flow endpoints
routes.append(Route("/oauth2/authorize", handle_oauth_authorize, methods=["GET", "OPTIONS"]))
routes.append(Route("/oauth2/token", handle_proxy_token_exchange, methods=["POST", "OPTIONS"]))
routes.append(Route("/oauth2/register", handle_oauth_register, methods=["POST", "OPTIONS"]))
routes.append(
Route(
"/oauth2/authorize", handle_oauth_authorize, methods=["GET", "OPTIONS"]
)
)
routes.append(
Route(
"/oauth2/token",
handle_proxy_token_exchange,
methods=["POST", "OPTIONS"],
)
)
routes.append(
Route(
"/oauth2/register", handle_oauth_register, methods=["POST", "OPTIONS"]
)
)
logger.info(f"Registered {len(routes)} OAuth routes")
return routes
async def verify_token(self, token: str) -> Optional[object]:
"""
Override verify_token to handle Google OAuth access tokens.
@@ -143,22 +162,30 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
"""
# Check if this is a Google OAuth access token (starts with ya29.)
if token.startswith("ya29."):
logger.debug("Detected Google OAuth access token, using tokeninfo verification")
logger.debug(
"Detected Google OAuth access token, using tokeninfo verification"
)
try:
# Verify the access token using Google's tokeninfo endpoint
async with aiohttp.ClientSession() as session:
url = f"https://oauth2.googleapis.com/tokeninfo?access_token={token}"
url = (
f"https://oauth2.googleapis.com/tokeninfo?access_token={token}"
)
async with session.get(url) as response:
if response.status != 200:
logger.error(f"Token verification failed: {response.status}")
logger.error(
f"Token verification failed: {response.status}"
)
return None
token_info = await response.json()
# Verify the token is for our client
if token_info.get("aud") != self.client_id:
logger.error(f"Token audience mismatch: expected {self.client_id}, got {token_info.get('aud')}")
logger.error(
f"Token audience mismatch: expected {self.client_id}, got {token_info.get('aud')}"
)
return None
# Check if token is expired
@@ -173,7 +200,9 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
# Calculate expires_at timestamp
expires_in = int(token_info.get("expires_in", 0))
expires_at = int(time.time()) + expires_in if expires_in > 0 else 0
expires_at = (
int(time.time()) + expires_in if expires_in > 0 else 0
)
access_token = SimpleNamespace(
claims={
@@ -188,12 +217,15 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
client_id=self.client_id, # Add client_id at top level
# Add other required fields
sub=token_info.get("sub", ""),
email=token_info.get("email", "")
email=token_info.get("email", ""),
)
user_email = token_info.get("email")
if user_email:
from auth.oauth21_session_store import get_oauth21_session_store
from auth.oauth21_session_store import (
get_oauth21_session_store,
)
store = get_oauth21_session_store()
session_id = f"google_{token_info.get('sub', 'unknown')}"
@@ -201,10 +233,13 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
mcp_session_id = None
try:
from fastmcp.server.dependencies import get_context
ctx = get_context()
if ctx and hasattr(ctx, 'session_id'):
if ctx and hasattr(ctx, "session_id"):
mcp_session_id = ctx.session_id
logger.debug(f"Binding MCP session {mcp_session_id} to user {user_email}")
logger.debug(
f"Binding MCP session {mcp_session_id} to user {user_email}"
)
except Exception:
pass
@@ -215,7 +250,7 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
scopes=access_token.scopes,
session_id=session_id,
mcp_session_id=mcp_session_id,
issuer="https://accounts.google.com"
issuer="https://accounts.google.com",
)
logger.info(f"Verified OAuth token: {user_email}")
@@ -236,6 +271,7 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
user_email = access_token.claims.get("email")
if user_email:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
session_id = f"google_{access_token.claims.get('sub', 'unknown')}"
@@ -245,9 +281,11 @@ class GoogleRemoteAuthProvider(RemoteAuthProvider):
access_token=token,
scopes=access_token.scopes or [],
session_id=session_id,
issuer="https://accounts.google.com"
issuer="https://accounts.google.com",
)
logger.debug(f"Successfully verified JWT token for user: {user_email}")
logger.debug(
f"Successfully verified JWT token for user: {user_email}"
)
return access_token
return access_token