refactor oauth2.1 support to fastmcp native
This commit is contained in:
@@ -9,6 +9,8 @@ from types import SimpleNamespace
|
||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
from auth.oauth21_session_store import ensure_session_from_access_token
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,7 +23,7 @@ class AuthInfoMiddleware(Middleware):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.auth_provider_type = "Unknown"
|
||||
self.auth_provider_type = "GoogleProvider"
|
||||
|
||||
async def _process_request_for_auth(self, context: MiddlewareContext):
|
||||
"""Helper to extract, verify, and store auth info from a request."""
|
||||
@@ -87,6 +89,13 @@ class AuthInfoMiddleware(Middleware):
|
||||
|
||||
# Store in context state - this is the authoritative authentication state
|
||||
context.fastmcp_context.set_state("access_token", access_token)
|
||||
mcp_session_id = getattr(context.fastmcp_context, "session_id", None)
|
||||
ensure_session_from_access_token(
|
||||
verified_auth,
|
||||
user_email,
|
||||
mcp_session_id,
|
||||
)
|
||||
context.fastmcp_context.set_state("access_token_obj", verified_auth)
|
||||
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
|
||||
context.fastmcp_context.set_state("token_type", "google_oauth")
|
||||
context.fastmcp_context.set_state("user_email", user_email)
|
||||
|
||||
@@ -1,172 +0,0 @@
|
||||
"""
|
||||
Google Workspace Authentication Provider for FastMCP
|
||||
|
||||
This module implements OAuth 2.1 authentication for Google Workspace using FastMCP's
|
||||
built-in authentication patterns. It acts as a Resource Server (RS) that trusts
|
||||
Google as the Authorization Server (AS).
|
||||
|
||||
Key features:
|
||||
- JWT token verification using Google's public keys
|
||||
- Discovery metadata endpoints for MCP protocol compliance
|
||||
- CORS proxy endpoints to work around Google's CORS limitations
|
||||
- Session bridging to Google credentials for API access
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
from starlette.routing import Route
|
||||
|
||||
from fastmcp.server.auth.auth import AuthProvider
|
||||
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
||||
from mcp.server.auth.provider import AccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GoogleWorkspaceAuthProvider(AuthProvider):
|
||||
"""
|
||||
Authentication provider for Google Workspace integration.
|
||||
|
||||
This provider implements the Remote Authentication pattern where:
|
||||
- Google acts as the Authorization Server (AS)
|
||||
- This MCP server acts as a Resource Server (RS)
|
||||
- Tokens are verified using Google's public keys
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Google Workspace auth provider."""
|
||||
super().__init__()
|
||||
|
||||
# Get configuration from OAuth config
|
||||
from auth.oauth_config import get_oauth_config
|
||||
config = get_oauth_config()
|
||||
|
||||
self.client_id = config.client_id
|
||||
self.client_secret = config.client_secret
|
||||
self.base_url = config.get_oauth_base_url()
|
||||
self.port = config.port
|
||||
|
||||
if not self.client_id:
|
||||
logger.warning("GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work")
|
||||
return
|
||||
|
||||
# Initialize JWT verifier for Google tokens
|
||||
self.jwt_verifier = JWTVerifier(
|
||||
jwks_uri="https://www.googleapis.com/oauth2/v3/certs",
|
||||
issuer="https://accounts.google.com",
|
||||
audience=self.client_id,
|
||||
algorithm="RS256"
|
||||
)
|
||||
|
||||
# Session bridging now handled by OAuth21SessionStore
|
||||
|
||||
async def verify_token(self, token: str) -> Optional[AccessToken]:
|
||||
"""
|
||||
Verify a bearer token issued by Google.
|
||||
|
||||
Args:
|
||||
token: The bearer token to verify
|
||||
|
||||
Returns:
|
||||
AccessToken object if valid, None otherwise
|
||||
"""
|
||||
if not self.client_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Use FastMCP's JWT verifier
|
||||
access_token = await self.jwt_verifier.verify_token(token)
|
||||
|
||||
if access_token:
|
||||
# Store session info in OAuth21SessionStore for credential bridging
|
||||
user_email = access_token.claims.get("email")
|
||||
if user_email:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
store = get_oauth21_session_store()
|
||||
session_id = f"google_{access_token.claims.get('sub', 'unknown')}"
|
||||
|
||||
# Try to get FastMCP session ID for binding
|
||||
mcp_session_id = None
|
||||
try:
|
||||
from fastmcp.server.dependencies import get_context
|
||||
ctx = get_context()
|
||||
if ctx and hasattr(ctx, 'session_id'):
|
||||
mcp_session_id = ctx.session_id
|
||||
logger.debug(f"Binding MCP session {mcp_session_id} to user {user_email}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
store.store_session(
|
||||
user_email=user_email,
|
||||
access_token=token,
|
||||
scopes=access_token.scopes or [],
|
||||
session_id=session_id,
|
||||
mcp_session_id=mcp_session_id
|
||||
)
|
||||
|
||||
logger.debug(f"Successfully verified Google token for user: {user_email}")
|
||||
|
||||
return access_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify Google token: {e}")
|
||||
return None
|
||||
|
||||
def customize_auth_routes(self, routes: List[Route]) -> List[Route]:
|
||||
"""
|
||||
NOTE: This method is not currently used. All OAuth 2.1 routes are implemented
|
||||
directly in core/server.py using @server.custom_route decorators.
|
||||
|
||||
This method exists for compatibility with FastMCP's AuthProvider interface
|
||||
but the routes it would define are handled elsewhere.
|
||||
"""
|
||||
# Routes are implemented directly in core/server.py
|
||||
return routes
|
||||
|
||||
def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get session information for credential bridging from OAuth21SessionStore.
|
||||
|
||||
Args:
|
||||
session_id: The session identifier
|
||||
|
||||
Returns:
|
||||
Session information if found
|
||||
"""
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
# Try to get user by session_id (assuming it's the MCP session ID)
|
||||
user_email = store.get_user_by_mcp_session(session_id)
|
||||
if user_email:
|
||||
credentials = store.get_credentials(user_email)
|
||||
if credentials:
|
||||
return {
|
||||
"access_token": credentials.token,
|
||||
"user_email": user_email,
|
||||
"scopes": credentials.scopes or []
|
||||
}
|
||||
return None
|
||||
|
||||
def create_session_from_token(self, token: str, user_email: str) -> str:
|
||||
"""
|
||||
Create a session from an access token for credential bridging using OAuth21SessionStore.
|
||||
|
||||
Args:
|
||||
token: The access token
|
||||
user_email: The user's email address
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
store = get_oauth21_session_store()
|
||||
session_id = f"google_{user_email}"
|
||||
|
||||
store.store_session(
|
||||
user_email=user_email,
|
||||
access_token=token,
|
||||
session_id=session_id
|
||||
)
|
||||
return session_id
|
||||
@@ -1,295 +0,0 @@
|
||||
"""
|
||||
Google Workspace RemoteAuthProvider for FastMCP v2.11.1+
|
||||
|
||||
This module implements OAuth 2.1 authentication for Google Workspace using FastMCP's
|
||||
RemoteAuthProvider pattern. It provides:
|
||||
|
||||
- JWT token verification using Google's public keys
|
||||
- OAuth proxy endpoints to work around CORS restrictions
|
||||
- Dynamic client registration workaround
|
||||
- Session bridging to Google credentials for API access
|
||||
|
||||
This provider is used only in streamable-http transport mode with FastMCP v2.11.1+.
|
||||
For earlier versions or other transport modes, the legacy GoogleWorkspaceAuthProvider is used.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import aiohttp
|
||||
from typing import Optional, List
|
||||
|
||||
from starlette.routing import Route
|
||||
from pydantic import AnyHttpUrl
|
||||
|
||||
try:
|
||||
from fastmcp.server.auth import RemoteAuthProvider
|
||||
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
||||
|
||||
REMOTEAUTHPROVIDER_AVAILABLE = True
|
||||
except ImportError:
|
||||
REMOTEAUTHPROVIDER_AVAILABLE = False
|
||||
RemoteAuthProvider = object # Fallback for type hints
|
||||
JWTVerifier = object
|
||||
|
||||
|
||||
# Import common OAuth handlers
|
||||
from auth.oauth_common_handlers import (
|
||||
handle_oauth_authorize,
|
||||
handle_proxy_token_exchange,
|
||||
handle_oauth_protected_resource,
|
||||
handle_oauth_authorization_server,
|
||||
handle_oauth_client_config,
|
||||
handle_oauth_register,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GoogleRemoteAuthProvider(RemoteAuthProvider):
|
||||
"""
|
||||
RemoteAuthProvider implementation for Google Workspace.
|
||||
|
||||
This provider extends RemoteAuthProvider to add:
|
||||
- OAuth proxy endpoints for CORS workaround
|
||||
- Dynamic client registration support
|
||||
- Session management with issuer tracking
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Google RemoteAuthProvider."""
|
||||
if not REMOTEAUTHPROVIDER_AVAILABLE:
|
||||
raise ImportError("FastMCP v2.11.1+ required for RemoteAuthProvider")
|
||||
|
||||
# Get configuration from OAuth config
|
||||
from auth.oauth_config import get_oauth_config
|
||||
config = get_oauth_config()
|
||||
|
||||
self.client_id = config.client_id
|
||||
self.client_secret = config.client_secret
|
||||
self.base_url = config.get_oauth_base_url()
|
||||
self.port = config.port
|
||||
|
||||
if not self.client_id:
|
||||
logger.error(
|
||||
"GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work"
|
||||
)
|
||||
raise ValueError(
|
||||
"GOOGLE_OAUTH_CLIENT_ID environment variable is required for OAuth 2.1 authentication"
|
||||
)
|
||||
|
||||
# Configure JWT verifier for Google tokens
|
||||
token_verifier = JWTVerifier(
|
||||
jwks_uri="https://www.googleapis.com/oauth2/v3/certs",
|
||||
issuer="https://accounts.google.com",
|
||||
audience=self.client_id, # Always use actual client_id
|
||||
algorithm="RS256",
|
||||
)
|
||||
|
||||
# Initialize RemoteAuthProvider with base URL (no /mcp/ suffix)
|
||||
# The /mcp/ resource URL is handled in the protected resource metadata endpoint
|
||||
super().__init__(
|
||||
token_verifier=token_verifier,
|
||||
authorization_servers=[AnyHttpUrl(self.base_url)],
|
||||
resource_server_url=self.base_url,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Initialized GoogleRemoteAuthProvider with base_url={self.base_url}"
|
||||
)
|
||||
|
||||
def get_routes(self) -> List[Route]:
|
||||
"""
|
||||
Add OAuth routes at canonical locations.
|
||||
"""
|
||||
# Get the standard OAuth protected resource routes from RemoteAuthProvider
|
||||
parent_routes = super().get_routes()
|
||||
|
||||
# Filter out the parent's oauth-protected-resource route since we're replacing it
|
||||
routes = [
|
||||
r
|
||||
for r in parent_routes
|
||||
if r.path != "/.well-known/oauth-protected-resource"
|
||||
]
|
||||
|
||||
# Add our custom OAuth discovery endpoint that returns /mcp/ as the resource
|
||||
routes.append(
|
||||
Route(
|
||||
"/.well-known/oauth-protected-resource",
|
||||
handle_oauth_protected_resource,
|
||||
methods=["GET", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
|
||||
routes.append(
|
||||
Route(
|
||||
"/.well-known/oauth-authorization-server",
|
||||
handle_oauth_authorization_server,
|
||||
methods=["GET", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
|
||||
routes.append(
|
||||
Route(
|
||||
"/.well-known/oauth-client",
|
||||
handle_oauth_client_config,
|
||||
methods=["GET", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
|
||||
# Add OAuth flow endpoints
|
||||
routes.append(
|
||||
Route(
|
||||
"/oauth2/authorize", handle_oauth_authorize, methods=["GET", "OPTIONS"]
|
||||
)
|
||||
)
|
||||
routes.append(
|
||||
Route(
|
||||
"/oauth2/token",
|
||||
handle_proxy_token_exchange,
|
||||
methods=["POST", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
routes.append(
|
||||
Route(
|
||||
"/oauth2/register", handle_oauth_register, methods=["POST", "OPTIONS"]
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Registered {len(routes)} OAuth routes")
|
||||
return routes
|
||||
|
||||
async def verify_token(self, token: str) -> Optional[object]:
|
||||
"""
|
||||
Override verify_token to handle Google OAuth access tokens.
|
||||
|
||||
Google OAuth access tokens (ya29.*) are opaque tokens that need to be
|
||||
verified using the tokeninfo endpoint, not JWT verification.
|
||||
"""
|
||||
# Check if this is a Google OAuth access token (starts with ya29.)
|
||||
if token.startswith("ya29."):
|
||||
logger.debug(
|
||||
"Detected Google OAuth access token, using tokeninfo verification"
|
||||
)
|
||||
|
||||
try:
|
||||
# Verify the access token using Google's tokeninfo endpoint
|
||||
async with aiohttp.ClientSession() as session:
|
||||
url = (
|
||||
f"https://oauth2.googleapis.com/tokeninfo?access_token={token}"
|
||||
)
|
||||
async with session.get(url) as response:
|
||||
if response.status != 200:
|
||||
logger.error(
|
||||
f"Token verification failed: {response.status}"
|
||||
)
|
||||
return None
|
||||
|
||||
token_info = await response.json()
|
||||
|
||||
# Verify the token is for our client
|
||||
if token_info.get("aud") != self.client_id:
|
||||
logger.error(
|
||||
f"Token audience mismatch: expected {self.client_id}, got {token_info.get('aud')}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if token is expired
|
||||
expires_in = token_info.get("expires_in", 0)
|
||||
if int(expires_in) <= 0:
|
||||
logger.error("Token is expired")
|
||||
return None
|
||||
|
||||
# Create an access token object that matches the expected interface
|
||||
from types import SimpleNamespace
|
||||
import time
|
||||
|
||||
# Calculate expires_at timestamp
|
||||
expires_in = int(token_info.get("expires_in", 0))
|
||||
expires_at = (
|
||||
int(time.time()) + expires_in if expires_in > 0 else 0
|
||||
)
|
||||
|
||||
access_token = SimpleNamespace(
|
||||
claims={
|
||||
"email": token_info.get("email"),
|
||||
"sub": token_info.get("sub"),
|
||||
"aud": token_info.get("aud"),
|
||||
"scope": token_info.get("scope", ""),
|
||||
},
|
||||
scopes=token_info.get("scope", "").split(),
|
||||
token=token,
|
||||
expires_at=expires_at, # Add the expires_at attribute
|
||||
client_id=self.client_id, # Add client_id at top level
|
||||
# Add other required fields
|
||||
sub=token_info.get("sub", ""),
|
||||
email=token_info.get("email", ""),
|
||||
)
|
||||
|
||||
user_email = token_info.get("email")
|
||||
if user_email:
|
||||
from auth.oauth21_session_store import (
|
||||
get_oauth21_session_store,
|
||||
)
|
||||
|
||||
store = get_oauth21_session_store()
|
||||
session_id = f"google_{token_info.get('sub', 'unknown')}"
|
||||
|
||||
# Try to get FastMCP session ID for binding
|
||||
mcp_session_id = None
|
||||
try:
|
||||
from fastmcp.server.dependencies import get_context
|
||||
|
||||
ctx = get_context()
|
||||
if ctx and hasattr(ctx, "session_id"):
|
||||
mcp_session_id = ctx.session_id
|
||||
logger.debug(
|
||||
f"Binding MCP session {mcp_session_id} to user {user_email}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Store session with issuer information
|
||||
store.store_session(
|
||||
user_email=user_email,
|
||||
access_token=token,
|
||||
scopes=access_token.scopes,
|
||||
session_id=session_id,
|
||||
mcp_session_id=mcp_session_id,
|
||||
issuer="https://accounts.google.com",
|
||||
)
|
||||
|
||||
logger.info(f"Verified OAuth token: {user_email}")
|
||||
|
||||
return access_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying Google OAuth token: {e}")
|
||||
return None
|
||||
|
||||
else:
|
||||
# For JWT tokens, use parent's JWT verification
|
||||
logger.debug("Using JWT verification for non-OAuth token")
|
||||
access_token = await super().verify_token(token)
|
||||
|
||||
if access_token and self.client_id:
|
||||
# Extract user information from token claims
|
||||
user_email = access_token.claims.get("email")
|
||||
if user_email:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
|
||||
store = get_oauth21_session_store()
|
||||
session_id = f"google_{access_token.claims.get('sub', 'unknown')}"
|
||||
|
||||
# Store session with issuer information
|
||||
store.store_session(
|
||||
user_email=user_email,
|
||||
access_token=token,
|
||||
scopes=access_token.scopes or [],
|
||||
session_id=session_id,
|
||||
issuer="https://accounts.google.com",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully verified JWT token for user: {user_email}"
|
||||
)
|
||||
|
||||
return access_token
|
||||
@@ -8,11 +8,12 @@ session context management and credential conversion functionality.
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
from threading import RLock
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastmcp.server.auth import AccessToken
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -565,6 +566,131 @@ def get_auth_provider():
|
||||
return _auth_provider
|
||||
|
||||
|
||||
def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Resolve OAuth client credentials from the active provider or configuration."""
|
||||
client_id: Optional[str] = None
|
||||
client_secret: Optional[str] = None
|
||||
|
||||
if _auth_provider:
|
||||
client_id = getattr(_auth_provider, "_upstream_client_id", None)
|
||||
secret_obj = getattr(_auth_provider, "_upstream_client_secret", None)
|
||||
if secret_obj is not None:
|
||||
if hasattr(secret_obj, "get_secret_value"):
|
||||
try:
|
||||
client_secret = secret_obj.get_secret_value() # type: ignore[call-arg]
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(f"Failed to resolve client secret from provider: {exc}")
|
||||
elif isinstance(secret_obj, str):
|
||||
client_secret = secret_obj
|
||||
|
||||
if not client_id or not client_secret:
|
||||
try:
|
||||
from auth.oauth_config import get_oauth_config
|
||||
|
||||
cfg = get_oauth_config()
|
||||
client_id = client_id or cfg.client_id
|
||||
client_secret = client_secret or cfg.client_secret
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(f"Failed to resolve client credentials from config: {exc}")
|
||||
|
||||
return client_id, client_secret
|
||||
|
||||
|
||||
def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Credentials]:
|
||||
"""Construct Google credentials from the provider cache."""
|
||||
if not _auth_provider:
|
||||
return None
|
||||
|
||||
access_entry = getattr(_auth_provider, "_access_tokens", {}).get(access_token.token)
|
||||
if not access_entry:
|
||||
access_entry = access_token
|
||||
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
|
||||
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(access_token.token)
|
||||
refresh_token_obj = None
|
||||
if refresh_token_value:
|
||||
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(refresh_token_value)
|
||||
|
||||
expiry = None
|
||||
expires_at = getattr(access_entry, "expires_at", None)
|
||||
if expires_at:
|
||||
try:
|
||||
expiry = datetime.utcfromtimestamp(expires_at)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
expiry = None
|
||||
|
||||
scopes = getattr(access_entry, "scopes", None)
|
||||
|
||||
return Credentials(
|
||||
token=access_token.token,
|
||||
refresh_token=refresh_token_obj.token if refresh_token_obj else None,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scopes=scopes,
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
|
||||
def ensure_session_from_access_token(
|
||||
access_token: AccessToken,
|
||||
user_email: Optional[str],
|
||||
mcp_session_id: Optional[str] = None,
|
||||
) -> Optional[Credentials]:
|
||||
"""Ensure credentials derived from an access token are cached and returned."""
|
||||
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
email = user_email
|
||||
if not email and getattr(access_token, "claims", None):
|
||||
email = access_token.claims.get("email")
|
||||
|
||||
credentials = _build_credentials_from_provider(access_token)
|
||||
|
||||
if credentials is None:
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
expiry = None
|
||||
expires_at = getattr(access_token, "expires_at", None)
|
||||
if expires_at:
|
||||
try:
|
||||
expiry = datetime.utcfromtimestamp(expires_at)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
expiry = None
|
||||
|
||||
credentials = Credentials(
|
||||
token=access_token.token,
|
||||
refresh_token=None,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scopes=getattr(access_token, "scopes", None),
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
if email:
|
||||
try:
|
||||
store = get_oauth21_session_store()
|
||||
store.store_session(
|
||||
user_email=email,
|
||||
access_token=credentials.token,
|
||||
refresh_token=credentials.refresh_token,
|
||||
token_uri=credentials.token_uri,
|
||||
client_id=credentials.client_id,
|
||||
client_secret=credentials.client_secret,
|
||||
scopes=credentials.scopes,
|
||||
expiry=credentials.expiry,
|
||||
session_id=f"google_{email}",
|
||||
mcp_session_id=mcp_session_id,
|
||||
issuer="https://accounts.google.com",
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(f"Failed to cache credentials for {email}: {exc}")
|
||||
|
||||
return credentials
|
||||
|
||||
|
||||
def get_credentials_from_token(access_token: str, user_email: Optional[str] = None) -> Optional[Credentials]:
|
||||
"""
|
||||
Convert a bearer token to Google credentials.
|
||||
@@ -576,10 +702,6 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
Returns:
|
||||
Google Credentials object or None
|
||||
"""
|
||||
if not _auth_provider:
|
||||
logger.error("Auth provider not configured")
|
||||
return None
|
||||
|
||||
try:
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
@@ -590,21 +712,29 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
logger.debug(f"Found matching credentials from store for {user_email}")
|
||||
return credentials
|
||||
|
||||
# If the FastMCP provider is managing tokens, sync from provider storage
|
||||
if _auth_provider:
|
||||
access_record = getattr(_auth_provider, "_access_tokens", {}).get(access_token)
|
||||
if access_record:
|
||||
logger.debug("Building credentials from FastMCP provider cache")
|
||||
return ensure_session_from_access_token(access_record, user_email)
|
||||
|
||||
# Otherwise, create minimal credentials with just the access token
|
||||
# Assume token is valid for 1 hour (typical for Google tokens)
|
||||
expiry = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
|
||||
credentials = Credentials(
|
||||
token=access_token,
|
||||
refresh_token=None,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=_auth_provider.client_id,
|
||||
client_secret=_auth_provider.client_secret,
|
||||
scopes=None, # Will be populated from token claims if available
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scopes=None,
|
||||
expiry=expiry
|
||||
)
|
||||
|
||||
logger.debug("Created Google credentials from bearer token")
|
||||
logger.debug("Created fallback Google credentials from bearer token")
|
||||
return credentials
|
||||
|
||||
except Exception as e:
|
||||
@@ -643,18 +773,23 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
session_id = f"google_{user_email}"
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
scopes = token_response.get("scope", "")
|
||||
scopes_list = scopes.split() if scopes else None
|
||||
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600))
|
||||
|
||||
store.store_session(
|
||||
user_email=user_email,
|
||||
access_token=token_response.get("access_token"),
|
||||
refresh_token=token_response.get("refresh_token"),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=_auth_provider.client_id,
|
||||
client_secret=_auth_provider.client_secret,
|
||||
scopes=token_response.get("scope", "").split() if token_response.get("scope") else None,
|
||||
expiry=datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600)),
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scopes=scopes_list,
|
||||
expiry=expiry,
|
||||
session_id=session_id,
|
||||
mcp_session_id=mcp_session_id,
|
||||
issuer="https://accounts.google.com", # Add issuer for Google tokens
|
||||
issuer="https://accounts.google.com",
|
||||
)
|
||||
|
||||
if mcp_session_id:
|
||||
|
||||
@@ -1,422 +0,0 @@
|
||||
"""Common OAuth 2.1 request handlers used by both legacy and modern auth providers."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from urllib.parse import urlencode, parse_qs
|
||||
|
||||
import aiohttp
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, RedirectResponse
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from auth.oauth21_session_store import store_token_session
|
||||
from auth.google_auth import get_credential_store
|
||||
from auth.scopes import get_current_scopes
|
||||
from auth.oauth_config import get_oauth_config, is_stateless_mode
|
||||
from auth.oauth_error_handling import (
|
||||
OAuthError, OAuthValidationError, OAuthConfigurationError,
|
||||
create_oauth_error_response, validate_token_request,
|
||||
validate_registration_request, get_development_cors_headers,
|
||||
log_security_event
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def handle_oauth_authorize(request: Request):
|
||||
"""Common handler for OAuth authorization proxy."""
|
||||
origin = request.headers.get("origin")
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
return JSONResponse(content={}, headers=cors_headers)
|
||||
|
||||
# Get query parameters
|
||||
params = dict(request.query_params)
|
||||
|
||||
# Add our client ID if not provided
|
||||
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
if "client_id" not in params and client_id:
|
||||
params["client_id"] = client_id
|
||||
|
||||
# Ensure response_type is code
|
||||
params["response_type"] = "code"
|
||||
|
||||
# Merge client scopes with scopes for enabled tools only
|
||||
client_scopes = params.get("scope", "").split() if params.get("scope") else []
|
||||
# Include scopes for enabled tools only (not all tools)
|
||||
enabled_tool_scopes = get_current_scopes()
|
||||
all_scopes = set(client_scopes) | set(enabled_tool_scopes)
|
||||
params["scope"] = " ".join(sorted(all_scopes))
|
||||
logger.info(f"OAuth 2.1 authorization: Requesting scopes: {params['scope']}")
|
||||
|
||||
# Build Google authorization URL
|
||||
google_auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params)
|
||||
|
||||
# Return redirect with development CORS headers if needed
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
return RedirectResponse(
|
||||
url=google_auth_url,
|
||||
status_code=302,
|
||||
headers=cors_headers
|
||||
)
|
||||
|
||||
|
||||
async def handle_proxy_token_exchange(request: Request):
|
||||
"""Common handler for OAuth token exchange proxy with comprehensive error handling."""
|
||||
origin = request.headers.get("origin")
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
return JSONResponse(content={}, headers=cors_headers)
|
||||
try:
|
||||
# Get form data with validation
|
||||
try:
|
||||
body = await request.body()
|
||||
content_type = request.headers.get("content-type", "application/x-www-form-urlencoded")
|
||||
except Exception as e:
|
||||
raise OAuthValidationError(f"Failed to read request body: {e}")
|
||||
|
||||
# Parse and validate form data
|
||||
if content_type and "application/x-www-form-urlencoded" in content_type:
|
||||
try:
|
||||
form_data = parse_qs(body.decode('utf-8'))
|
||||
except Exception as e:
|
||||
raise OAuthValidationError(f"Invalid form data: {e}")
|
||||
|
||||
# Convert to single values and validate
|
||||
request_data = {k: v[0] if v else '' for k, v in form_data.items()}
|
||||
validate_token_request(request_data)
|
||||
|
||||
# Check if client_id is missing (public client)
|
||||
if 'client_id' not in form_data or not form_data['client_id'][0]:
|
||||
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
if client_id:
|
||||
form_data['client_id'] = [client_id]
|
||||
logger.debug("Added missing client_id to token request")
|
||||
|
||||
# Check if client_secret is missing (public client using PKCE)
|
||||
if 'client_secret' not in form_data:
|
||||
client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||
if client_secret:
|
||||
form_data['client_secret'] = [client_secret]
|
||||
logger.debug("Added missing client_secret to token request")
|
||||
|
||||
# Reconstruct body with added credentials
|
||||
body = urlencode(form_data, doseq=True).encode('utf-8')
|
||||
|
||||
# Forward request to Google
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {"Content-Type": content_type}
|
||||
|
||||
async with session.post("https://oauth2.googleapis.com/token", data=body, headers=headers) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
# Log for debugging
|
||||
if response.status != 200:
|
||||
logger.error(f"Token exchange failed: {response.status} - {response_data}")
|
||||
else:
|
||||
logger.info("Token exchange successful")
|
||||
|
||||
# Store the token session for credential bridging
|
||||
if "access_token" in response_data:
|
||||
try:
|
||||
# Extract user email from ID token if present
|
||||
if "id_token" in response_data:
|
||||
# Verify ID token using Google's public keys for security
|
||||
try:
|
||||
# Get Google's public keys for verification
|
||||
jwks_client = PyJWKClient("https://www.googleapis.com/oauth2/v3/certs")
|
||||
|
||||
# Get signing key from JWT header
|
||||
signing_key = jwks_client.get_signing_key_from_jwt(response_data["id_token"])
|
||||
|
||||
# Verify and decode the ID token
|
||||
id_token_claims = jwt.decode(
|
||||
response_data["id_token"],
|
||||
signing_key.key,
|
||||
algorithms=["RS256"],
|
||||
audience=os.getenv("GOOGLE_OAUTH_CLIENT_ID"),
|
||||
issuer="https://accounts.google.com"
|
||||
)
|
||||
user_email = id_token_claims.get("email")
|
||||
email_verified = id_token_claims.get("email_verified")
|
||||
|
||||
if not email_verified:
|
||||
logger.error(f"Email address for user {user_email} is not verified by Google. Aborting session creation.")
|
||||
return JSONResponse(content={"error": "Email address not verified"}, status_code=403)
|
||||
elif user_email:
|
||||
# Try to get FastMCP session ID from request context for binding
|
||||
mcp_session_id = None
|
||||
try:
|
||||
# Check if this is a streamable HTTP request with session
|
||||
if hasattr(request, 'state') and hasattr(request.state, 'session_id'):
|
||||
mcp_session_id = request.state.session_id
|
||||
logger.info(f"Found MCP session ID for binding: {mcp_session_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get MCP session ID: {e}")
|
||||
|
||||
# Store the token session with MCP session binding
|
||||
session_id = store_token_session(response_data, user_email, mcp_session_id)
|
||||
logger.info(f"Stored OAuth session for {user_email} (session: {session_id}, mcp: {mcp_session_id})")
|
||||
|
||||
# Also create and store Google credentials
|
||||
expiry = None
|
||||
if "expires_in" in response_data:
|
||||
# Google auth library expects timezone-naive datetime
|
||||
expiry = datetime.utcnow() + timedelta(seconds=response_data["expires_in"])
|
||||
|
||||
credentials = Credentials(
|
||||
token=response_data["access_token"],
|
||||
refresh_token=response_data.get("refresh_token"),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=os.getenv("GOOGLE_OAUTH_CLIENT_ID"),
|
||||
client_secret=os.getenv("GOOGLE_OAUTH_CLIENT_SECRET"),
|
||||
scopes=response_data.get("scope", "").split() if response_data.get("scope") else None,
|
||||
expiry=expiry
|
||||
)
|
||||
|
||||
# Save credentials to file for legacy auth (skip in stateless mode)
|
||||
if not is_stateless_mode():
|
||||
store = get_credential_store()
|
||||
if not store.store_credential(user_email, credentials):
|
||||
logger.error(f"Failed to save Google credentials for {user_email}")
|
||||
else:
|
||||
logger.info(f"Saved Google credentials for {user_email}")
|
||||
else:
|
||||
logger.info(f"Skipping credential file save in stateless mode for {user_email}")
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.error("ID token has expired - cannot extract user email")
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.error(f"Invalid ID token - cannot extract user email: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify ID token - cannot extract user email: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store OAuth session: {e}")
|
||||
|
||||
# Add development CORS headers
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
response_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Cache-Control": "no-store"
|
||||
}
|
||||
response_headers.update(cors_headers)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status,
|
||||
content=response_data,
|
||||
headers=response_headers
|
||||
)
|
||||
|
||||
except OAuthError as e:
|
||||
log_security_event("oauth_token_exchange_error", {
|
||||
"error_code": e.error_code,
|
||||
"description": e.description
|
||||
}, request)
|
||||
return create_oauth_error_response(e, origin)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in token proxy: {e}", exc_info=True)
|
||||
log_security_event("oauth_token_exchange_unexpected_error", {
|
||||
"error": str(e)
|
||||
}, request)
|
||||
error = OAuthConfigurationError("Internal server error")
|
||||
return create_oauth_error_response(error, origin)
|
||||
|
||||
|
||||
async def handle_oauth_protected_resource(request: Request):
|
||||
"""
|
||||
Handle OAuth protected resource metadata requests.
|
||||
"""
|
||||
origin = request.headers.get("origin")
|
||||
|
||||
# Handle preflight
|
||||
if request.method == "OPTIONS":
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
return JSONResponse(content={}, headers=cors_headers)
|
||||
|
||||
config = get_oauth_config()
|
||||
base_url = config.get_oauth_base_url()
|
||||
|
||||
# For streamable-http transport, the MCP server runs at /mcp
|
||||
# This is the actual resource being protected
|
||||
# As of August, /mcp is now the proper base - prior was /mcp/
|
||||
resource_url = f"{base_url}/mcp"
|
||||
|
||||
# Build metadata response per RFC 9449
|
||||
metadata = {
|
||||
"resource": resource_url, # The MCP server endpoint that needs protection
|
||||
"authorization_servers": [base_url], # Our proxy acts as the auth server
|
||||
"bearer_methods_supported": ["header"],
|
||||
"scopes_supported": get_current_scopes(),
|
||||
"resource_documentation": "https://developers.google.com/workspace",
|
||||
"client_registration_required": True,
|
||||
"client_configuration_endpoint": f"{base_url}/.well-known/oauth-client",
|
||||
}
|
||||
# Log the response for debugging
|
||||
logger.debug(f"Returning protected resource metadata: {metadata}")
|
||||
|
||||
# Add development CORS headers
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
response_headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Cache-Control": "public, max-age=3600"
|
||||
}
|
||||
response_headers.update(cors_headers)
|
||||
|
||||
return JSONResponse(
|
||||
content=metadata,
|
||||
headers=response_headers
|
||||
)
|
||||
|
||||
|
||||
async def handle_oauth_authorization_server(request: Request):
|
||||
"""
|
||||
Handle OAuth authorization server metadata.
|
||||
"""
|
||||
origin = request.headers.get("origin")
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
return JSONResponse(content={}, headers=cors_headers)
|
||||
|
||||
config = get_oauth_config()
|
||||
|
||||
# Get authorization server metadata from centralized config
|
||||
# Pass scopes directly to keep all metadata generation in one place
|
||||
metadata = config.get_authorization_server_metadata(scopes=get_current_scopes())
|
||||
|
||||
logger.debug(f"Returning authorization server metadata: {metadata}")
|
||||
|
||||
# Add development CORS headers
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
response_headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Cache-Control": "public, max-age=3600"
|
||||
}
|
||||
response_headers.update(cors_headers)
|
||||
|
||||
return JSONResponse(
|
||||
content=metadata,
|
||||
headers=response_headers
|
||||
)
|
||||
|
||||
|
||||
async def handle_oauth_client_config(request: Request):
|
||||
"""Common handler for OAuth client configuration."""
|
||||
origin = request.headers.get("origin")
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
return JSONResponse(content={}, headers=cors_headers)
|
||||
|
||||
client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
if not client_id:
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"error": "OAuth not configured"},
|
||||
headers=cors_headers
|
||||
)
|
||||
|
||||
# Get OAuth configuration
|
||||
config = get_oauth_config()
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"client_id": client_id,
|
||||
"client_name": "Google Workspace MCP Server",
|
||||
"client_uri": config.base_url,
|
||||
"redirect_uris": [
|
||||
f"{config.base_url}/oauth2callback",
|
||||
],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"scope": " ".join(get_current_scopes()),
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"code_challenge_methods": config.supported_code_challenge_methods[:1] # Primary method only
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Cache-Control": "public, max-age=3600",
|
||||
**get_development_cors_headers(origin)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def handle_oauth_register(request: Request):
|
||||
"""Common handler for OAuth dynamic client registration with comprehensive error handling."""
|
||||
origin = request.headers.get("origin")
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
return JSONResponse(content={}, headers=cors_headers)
|
||||
|
||||
config = get_oauth_config()
|
||||
|
||||
if not config.is_configured():
|
||||
error = OAuthConfigurationError("OAuth client credentials not configured")
|
||||
return create_oauth_error_response(error, origin)
|
||||
|
||||
try:
|
||||
# Parse and validate the registration request
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception as e:
|
||||
raise OAuthValidationError(f"Invalid JSON in registration request: {e}")
|
||||
|
||||
validate_registration_request(body)
|
||||
logger.info("Dynamic client registration request received")
|
||||
|
||||
# Extract redirect URIs from the request or use defaults
|
||||
redirect_uris = body.get("redirect_uris", [])
|
||||
if not redirect_uris:
|
||||
redirect_uris = config.get_redirect_uris()
|
||||
|
||||
# Build the registration response with our pre-configured credentials
|
||||
response_data = {
|
||||
"client_id": config.client_id,
|
||||
"client_secret": config.client_secret,
|
||||
"client_name": body.get("client_name", "Google Workspace MCP Server"),
|
||||
"client_uri": body.get("client_uri", config.base_url),
|
||||
"redirect_uris": redirect_uris,
|
||||
"grant_types": body.get("grant_types", ["authorization_code", "refresh_token"]),
|
||||
"response_types": body.get("response_types", ["code"]),
|
||||
"scope": body.get("scope", " ".join(get_current_scopes())),
|
||||
"token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_basic"),
|
||||
"code_challenge_methods": config.supported_code_challenge_methods,
|
||||
# Additional OAuth 2.1 fields
|
||||
"client_id_issued_at": int(time.time()),
|
||||
"registration_access_token": "not-required", # We don't implement client management
|
||||
"registration_client_uri": f"{config.get_oauth_base_url()}/oauth2/register/{config.client_id}"
|
||||
}
|
||||
|
||||
logger.info("Dynamic client registration successful - returning pre-configured Google credentials")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=201,
|
||||
content=response_data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Cache-Control": "no-store",
|
||||
**get_development_cors_headers(origin)
|
||||
}
|
||||
)
|
||||
|
||||
except OAuthError as e:
|
||||
log_security_event("oauth_registration_error", {
|
||||
"error_code": e.error_code,
|
||||
"description": e.description
|
||||
}, request)
|
||||
return create_oauth_error_response(e, origin)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in client registration: {e}", exc_info=True)
|
||||
log_security_event("oauth_registration_unexpected_error", {
|
||||
"error": str(e)
|
||||
}, request)
|
||||
error = OAuthConfigurationError("Internal server error")
|
||||
return create_oauth_error_response(error, origin)
|
||||
@@ -9,6 +9,7 @@ Supports both OAuth 2.0 and OAuth 2.1 with automatic client capability detection
|
||||
"""
|
||||
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
|
||||
@@ -49,6 +50,10 @@ class OAuthConfig:
|
||||
|
||||
# Redirect URI configuration
|
||||
self.redirect_uri = self._get_redirect_uri()
|
||||
self.redirect_path = self._get_redirect_path(self.redirect_uri)
|
||||
|
||||
# Ensure FastMCP's Google provider picks up our existing configuration
|
||||
self._apply_fastmcp_google_env()
|
||||
|
||||
def _get_redirect_uri(self) -> str:
|
||||
"""
|
||||
@@ -62,6 +67,32 @@ class OAuthConfig:
|
||||
return explicit_uri
|
||||
return f"{self.base_url}/oauth2callback"
|
||||
|
||||
@staticmethod
|
||||
def _get_redirect_path(uri: str) -> str:
|
||||
"""Extract the redirect path from a full redirect URI."""
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme or parsed.netloc:
|
||||
path = parsed.path or "/oauth2callback"
|
||||
else:
|
||||
# If the value was already a path, ensure it starts with '/'
|
||||
path = uri if uri.startswith("/") else f"/{uri}"
|
||||
return path or "/oauth2callback"
|
||||
|
||||
def _apply_fastmcp_google_env(self) -> None:
|
||||
"""Mirror legacy GOOGLE_* env vars into FastMCP Google provider settings."""
|
||||
if not self.client_id:
|
||||
return
|
||||
|
||||
def _set_if_absent(key: str, value: Optional[str]) -> None:
|
||||
if value and key not in os.environ:
|
||||
os.environ[key] = value
|
||||
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH", "fastmcp.server.auth.providers.google.GoogleProvider" if self.oauth21_enabled else None)
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_ID", self.client_id)
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_SECRET", self.client_secret)
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_BASE_URL", self.get_oauth_base_url())
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_REDIRECT_PATH", self.redirect_path)
|
||||
|
||||
def get_redirect_uris(self) -> List[str]:
|
||||
"""
|
||||
Get all valid OAuth redirect URIs.
|
||||
@@ -156,6 +187,7 @@ class OAuthConfig:
|
||||
"external_url": self.external_url,
|
||||
"effective_oauth_url": self.get_oauth_base_url(),
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"redirect_path": self.redirect_path,
|
||||
"client_configured": bool(self.client_id),
|
||||
"oauth21_enabled": self.oauth21_enabled,
|
||||
"pkce_required": self.pkce_required,
|
||||
@@ -350,4 +382,4 @@ def get_oauth_redirect_uri() -> str:
|
||||
|
||||
def is_stateless_mode() -> bool:
|
||||
"""Check if stateless mode is enabled."""
|
||||
return get_oauth_config().stateless_mode
|
||||
return get_oauth_config().stateless_mode
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
"""
|
||||
OAuth Error Handling and Validation
|
||||
|
||||
This module provides comprehensive error handling and input validation for OAuth
|
||||
endpoints, addressing the inconsistent error handling identified in the challenge review.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.requests import Request
|
||||
from urllib.parse import urlparse
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthError(Exception):
|
||||
"""Base exception for OAuth-related errors."""
|
||||
|
||||
def __init__(self, error_code: str, description: str, status_code: int = 400):
|
||||
self.error_code = error_code
|
||||
self.description = description
|
||||
self.status_code = status_code
|
||||
super().__init__(f"{error_code}: {description}")
|
||||
|
||||
|
||||
class OAuthValidationError(OAuthError):
|
||||
"""Exception for OAuth validation errors."""
|
||||
|
||||
def __init__(self, description: str, field: Optional[str] = None):
|
||||
error_code = "invalid_request"
|
||||
if field:
|
||||
description = f"Invalid {field}: {description}"
|
||||
super().__init__(error_code, description, 400)
|
||||
|
||||
|
||||
class OAuthConfigurationError(OAuthError):
|
||||
"""Exception for OAuth configuration errors."""
|
||||
|
||||
def __init__(self, description: str):
|
||||
super().__init__("server_error", description, 500)
|
||||
|
||||
|
||||
def create_oauth_error_response(error: OAuthError, origin: Optional[str] = None) -> JSONResponse:
|
||||
"""
|
||||
Create a standardized OAuth error response.
|
||||
|
||||
Args:
|
||||
error: The OAuth error to convert to a response
|
||||
origin: Optional origin for development CORS headers
|
||||
|
||||
Returns:
|
||||
JSONResponse with standardized error format
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Cache-Control": "no-store"
|
||||
}
|
||||
|
||||
# Add development CORS headers if needed
|
||||
cors_headers = get_development_cors_headers(origin)
|
||||
headers.update(cors_headers)
|
||||
|
||||
content = {
|
||||
"error": error.error_code,
|
||||
"error_description": error.description
|
||||
}
|
||||
|
||||
logger.warning(f"OAuth error response: {error.error_code} - {error.description}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=error.status_code,
|
||||
content=content,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
def validate_redirect_uri(uri: str) -> None:
|
||||
"""
|
||||
Validate an OAuth redirect URI.
|
||||
|
||||
Args:
|
||||
uri: The redirect URI to validate
|
||||
|
||||
Raises:
|
||||
OAuthValidationError: If the URI is invalid
|
||||
"""
|
||||
if not uri:
|
||||
raise OAuthValidationError("Redirect URI is required", "redirect_uri")
|
||||
|
||||
try:
|
||||
parsed = urlparse(uri)
|
||||
except Exception:
|
||||
raise OAuthValidationError("Malformed redirect URI", "redirect_uri")
|
||||
|
||||
# Basic URI validation
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
raise OAuthValidationError("Redirect URI must be absolute", "redirect_uri")
|
||||
|
||||
# Security checks
|
||||
if parsed.scheme not in ["http", "https"]:
|
||||
raise OAuthValidationError("Redirect URI must use HTTP or HTTPS", "redirect_uri")
|
||||
|
||||
# Additional security for production
|
||||
if parsed.scheme == "http" and parsed.hostname not in ["localhost", "127.0.0.1"]:
|
||||
logger.warning(f"Insecure redirect URI: {uri}")
|
||||
|
||||
|
||||
def validate_client_id(client_id: str) -> None:
|
||||
"""
|
||||
Validate an OAuth client ID.
|
||||
|
||||
Args:
|
||||
client_id: The client ID to validate
|
||||
|
||||
Raises:
|
||||
OAuthValidationError: If the client ID is invalid
|
||||
"""
|
||||
if not client_id:
|
||||
raise OAuthValidationError("Client ID is required", "client_id")
|
||||
|
||||
if len(client_id) < 10:
|
||||
raise OAuthValidationError("Client ID is too short", "client_id")
|
||||
|
||||
# Basic format validation for Google client IDs
|
||||
if not re.match(r'^[a-zA-Z0-9\-_.]+$', client_id):
|
||||
raise OAuthValidationError("Client ID contains invalid characters", "client_id")
|
||||
|
||||
|
||||
def validate_authorization_code(code: str) -> None:
|
||||
"""
|
||||
Validate an OAuth authorization code.
|
||||
|
||||
Args:
|
||||
code: The authorization code to validate
|
||||
|
||||
Raises:
|
||||
OAuthValidationError: If the code is invalid
|
||||
"""
|
||||
if not code:
|
||||
raise OAuthValidationError("Authorization code is required", "code")
|
||||
|
||||
if len(code) < 10:
|
||||
raise OAuthValidationError("Authorization code is too short", "code")
|
||||
|
||||
# Check for suspicious patterns
|
||||
if any(char in code for char in [' ', '\n', '\t', '<', '>']):
|
||||
raise OAuthValidationError("Authorization code contains invalid characters", "code")
|
||||
|
||||
|
||||
def validate_scopes(scopes: List[str]) -> None:
|
||||
"""
|
||||
Validate OAuth scopes.
|
||||
|
||||
Args:
|
||||
scopes: List of scopes to validate
|
||||
|
||||
Raises:
|
||||
OAuthValidationError: If the scopes are invalid
|
||||
"""
|
||||
if not scopes:
|
||||
return # Empty scopes list is acceptable
|
||||
|
||||
for scope in scopes:
|
||||
if not scope:
|
||||
raise OAuthValidationError("Empty scope is not allowed", "scope")
|
||||
|
||||
if len(scope) > 200:
|
||||
raise OAuthValidationError("Scope is too long", "scope")
|
||||
|
||||
# Basic scope format validation
|
||||
if not re.match(r'^[a-zA-Z0-9\-_.:/]+$', scope):
|
||||
raise OAuthValidationError(f"Invalid scope format: {scope}", "scope")
|
||||
|
||||
|
||||
def validate_token_request(request_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate an OAuth token exchange request.
|
||||
|
||||
Args:
|
||||
request_data: The token request data to validate
|
||||
|
||||
Raises:
|
||||
OAuthValidationError: If the request is invalid
|
||||
"""
|
||||
grant_type = request_data.get("grant_type")
|
||||
if not grant_type:
|
||||
raise OAuthValidationError("Grant type is required", "grant_type")
|
||||
|
||||
if grant_type not in ["authorization_code", "refresh_token"]:
|
||||
raise OAuthValidationError(f"Unsupported grant type: {grant_type}", "grant_type")
|
||||
|
||||
if grant_type == "authorization_code":
|
||||
code = request_data.get("code")
|
||||
validate_authorization_code(code)
|
||||
|
||||
redirect_uri = request_data.get("redirect_uri")
|
||||
if redirect_uri:
|
||||
validate_redirect_uri(redirect_uri)
|
||||
|
||||
client_id = request_data.get("client_id")
|
||||
if client_id:
|
||||
validate_client_id(client_id)
|
||||
|
||||
|
||||
def validate_registration_request(request_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate an OAuth client registration request.
|
||||
|
||||
Args:
|
||||
request_data: The registration request data to validate
|
||||
|
||||
Raises:
|
||||
OAuthValidationError: If the request is invalid
|
||||
"""
|
||||
# Validate redirect URIs if provided
|
||||
redirect_uris = request_data.get("redirect_uris", [])
|
||||
if redirect_uris:
|
||||
if not isinstance(redirect_uris, list):
|
||||
raise OAuthValidationError("redirect_uris must be an array", "redirect_uris")
|
||||
|
||||
for uri in redirect_uris:
|
||||
validate_redirect_uri(uri)
|
||||
|
||||
# Validate grant types if provided
|
||||
grant_types = request_data.get("grant_types", [])
|
||||
if grant_types:
|
||||
if not isinstance(grant_types, list):
|
||||
raise OAuthValidationError("grant_types must be an array", "grant_types")
|
||||
|
||||
allowed_grant_types = ["authorization_code", "refresh_token"]
|
||||
for grant_type in grant_types:
|
||||
if grant_type not in allowed_grant_types:
|
||||
raise OAuthValidationError(f"Unsupported grant type: {grant_type}", "grant_types")
|
||||
|
||||
# Validate response types if provided
|
||||
response_types = request_data.get("response_types", [])
|
||||
if response_types:
|
||||
if not isinstance(response_types, list):
|
||||
raise OAuthValidationError("response_types must be an array", "response_types")
|
||||
|
||||
allowed_response_types = ["code"]
|
||||
for response_type in response_types:
|
||||
if response_type not in allowed_response_types:
|
||||
raise OAuthValidationError(f"Unsupported response type: {response_type}", "response_types")
|
||||
|
||||
|
||||
def sanitize_user_input(value: str, max_length: int = 1000) -> str:
|
||||
"""
|
||||
Sanitize user input to prevent injection attacks.
|
||||
|
||||
Args:
|
||||
value: The input value to sanitize
|
||||
max_length: Maximum allowed length
|
||||
|
||||
Returns:
|
||||
Sanitized input value
|
||||
|
||||
Raises:
|
||||
OAuthValidationError: If the input is invalid
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
raise OAuthValidationError("Input must be a string")
|
||||
|
||||
if len(value) > max_length:
|
||||
raise OAuthValidationError(f"Input is too long (max {max_length} characters)")
|
||||
|
||||
# Remove potentially dangerous characters
|
||||
sanitized = re.sub(r'[<>"\'\0\n\r\t]', '', value)
|
||||
|
||||
return sanitized.strip()
|
||||
|
||||
|
||||
def log_security_event(event_type: str, details: Dict[str, Any], request: Optional[Request] = None) -> None:
|
||||
"""
|
||||
Log security-related events for monitoring.
|
||||
|
||||
Args:
|
||||
event_type: Type of security event
|
||||
details: Event details
|
||||
request: Optional request object for context
|
||||
"""
|
||||
log_data = {
|
||||
"event_type": event_type,
|
||||
"details": details
|
||||
}
|
||||
|
||||
if request:
|
||||
log_data["request"] = {
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"user_agent": request.headers.get("user-agent", "unknown"),
|
||||
"origin": request.headers.get("origin", "unknown")
|
||||
}
|
||||
|
||||
logger.warning(f"Security event: {log_data}")
|
||||
|
||||
|
||||
def get_development_cors_headers(origin: Optional[str] = None) -> Dict[str, str]:
|
||||
"""
|
||||
Get minimal CORS headers for development scenarios only.
|
||||
|
||||
Only allows localhost origins for development tools and inspectors.
|
||||
|
||||
Args:
|
||||
origin: The request origin (will be validated)
|
||||
|
||||
Returns:
|
||||
CORS headers for localhost origins only, empty dict otherwise
|
||||
"""
|
||||
# Only allow localhost origins for development
|
||||
if origin and (origin.startswith("http://localhost:") or origin.startswith("http://127.0.0.1:")):
|
||||
return {
|
||||
"Access-Control-Allow-Origin": origin,
|
||||
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization",
|
||||
"Access-Control-Max-Age": "3600"
|
||||
}
|
||||
|
||||
return {}
|
||||
@@ -7,9 +7,13 @@ from typing import Dict, List, Optional, Any, Callable, Union, Tuple
|
||||
|
||||
from google.auth.exceptions import RefreshError
|
||||
from googleapiclient.discovery import build
|
||||
from fastmcp.server.dependencies import get_context
|
||||
from fastmcp.server.dependencies import get_access_token, get_context
|
||||
from auth.google_auth import get_authenticated_google_service, GoogleAuthenticationError
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
from auth.oauth21_session_store import (
|
||||
get_auth_provider,
|
||||
get_oauth21_session_store,
|
||||
ensure_session_from_access_token,
|
||||
)
|
||||
from auth.oauth_config import is_oauth21_enabled, get_oauth_config
|
||||
from core.context import set_fastmcp_session_id
|
||||
from auth.scopes import (
|
||||
@@ -206,9 +210,52 @@ async def get_authenticated_google_service_oauth21(
|
||||
"""
|
||||
OAuth 2.1 authentication using the session store with security validation.
|
||||
"""
|
||||
provider = get_auth_provider()
|
||||
access_token = get_access_token()
|
||||
|
||||
if provider and access_token:
|
||||
token_email = None
|
||||
if getattr(access_token, "claims", None):
|
||||
token_email = access_token.claims.get("email")
|
||||
|
||||
resolved_email = token_email or auth_token_email or user_google_email
|
||||
if not resolved_email:
|
||||
raise GoogleAuthenticationError(
|
||||
"Authenticated user email could not be determined from access token."
|
||||
)
|
||||
|
||||
if auth_token_email and token_email and token_email != auth_token_email:
|
||||
raise GoogleAuthenticationError(
|
||||
"Access token email does not match authenticated session context."
|
||||
)
|
||||
|
||||
if token_email and user_google_email and token_email != user_google_email:
|
||||
raise GoogleAuthenticationError(
|
||||
f"Authenticated account {token_email} does not match requested user {user_google_email}."
|
||||
)
|
||||
|
||||
credentials = ensure_session_from_access_token(access_token, resolved_email, session_id)
|
||||
if not credentials:
|
||||
raise GoogleAuthenticationError(
|
||||
"Unable to build Google credentials from authenticated access token."
|
||||
)
|
||||
|
||||
scopes_available = set(credentials.scopes or [])
|
||||
if not scopes_available and getattr(access_token, "scopes", None):
|
||||
scopes_available = set(access_token.scopes)
|
||||
|
||||
if not all(scope in scopes_available for scope in required_scopes):
|
||||
raise GoogleAuthenticationError(
|
||||
f"OAuth credentials lack required scopes. Need: {required_scopes}, Have: {sorted(scopes_available)}"
|
||||
)
|
||||
|
||||
service = build(service_name, version, credentials=credentials)
|
||||
logger.info(f"[{tool_name}] Authenticated {service_name} for {resolved_email}")
|
||||
return service, resolved_email
|
||||
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
# Use the new validation method to ensure session can only access its own credentials
|
||||
# Use the validation method to ensure session can only access its own credentials
|
||||
credentials = store.get_credentials_with_validation(
|
||||
requested_user_email=user_google_email,
|
||||
session_id=session_id,
|
||||
@@ -222,13 +269,16 @@ async def get_authenticated_google_service_oauth21(
|
||||
f"You can only access credentials for your authenticated account."
|
||||
)
|
||||
|
||||
# Check scopes
|
||||
if not all(scope in credentials.scopes for scope in required_scopes):
|
||||
if not credentials.scopes:
|
||||
scopes_available = set(required_scopes)
|
||||
else:
|
||||
scopes_available = set(credentials.scopes)
|
||||
|
||||
if not all(scope in scopes_available for scope in required_scopes):
|
||||
raise GoogleAuthenticationError(
|
||||
f"OAuth 2.1 credentials lack required scopes. Need: {required_scopes}, Have: {credentials.scopes}"
|
||||
f"OAuth 2.1 credentials lack required scopes. Need: {required_scopes}, Have: {sorted(scopes_available)}"
|
||||
)
|
||||
|
||||
# Build service
|
||||
service = build(service_name, version, credentials=credentials)
|
||||
logger.info(f"[{tool_name}] Authenticated {service_name} for {user_google_email}")
|
||||
|
||||
@@ -731,4 +781,3 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user