Enhanced Session Management better guardrails

This commit is contained in:
Taylor Wilsdon
2025-08-03 15:51:04 -04:00
parent ff9b7ecd07
commit 71e2f1ba3e
4 changed files with 493 additions and 71 deletions

View File

@@ -23,25 +23,103 @@ from auth.oauth21_session_store import get_session_context
# OAuth 2.1 integration is now handled by FastMCP auth
OAUTH21_INTEGRATION_AVAILABLE = True
async def _extract_and_verify_bearer_token() -> tuple[Optional[str], Optional[str]]:
"""
Extract and verify bearer token from HTTP headers.
Returns:
Tuple of (user_email, verified_token) if valid, (None, None) if invalid or not found
"""
try:
from fastmcp.server.dependencies import get_http_headers
headers = get_http_headers()
if not headers:
logger.debug("No HTTP headers available for bearer token extraction")
return None, None
# Look for Authorization header
auth_header = headers.get("authorization") or headers.get("Authorization")
if not auth_header:
logger.debug("No Authorization header found in request")
return None, None
if not auth_header.lower().startswith("bearer "):
logger.debug(f"Authorization header present but not Bearer token: {auth_header[:20]}...")
return None, None
# Extract token
token = auth_header[7:] # Remove "Bearer " prefix
if not token:
logger.debug("Empty bearer token found")
return None, None
logger.debug(f"Found bearer token in Authorization header: {token[:20]}...")
# Verify token using GoogleWorkspaceAuthProvider
try:
from core.server import get_auth_provider
auth_provider = get_auth_provider()
if not auth_provider:
logger.debug("No auth provider available for token verification")
return None, None
# Verify the token
access_token = await auth_provider.verify_token(token)
if not access_token:
logger.debug("Bearer token verification failed")
return None, None
# Extract user email from verified token
user_email = access_token.claims.get("email")
if not user_email:
logger.debug("No email claim found in verified token")
return None, None
logger.info(f"Successfully verified bearer token for user: {user_email}")
return user_email, token
except Exception as e:
logger.error(f"Error verifying bearer token: {e}")
return None, None
except Exception as e:
logger.debug(f"Error extracting bearer token from headers: {e}")
return None, None
async def get_authenticated_google_service_oauth21(
service_name: str,
version: str,
tool_name: str,
user_google_email: str,
required_scopes: List[str],
session_id: Optional[str] = None,
auth_token_email: Optional[str] = None,
allow_recent_auth: bool = False,
) -> tuple[Any, str]:
"""
OAuth 2.1 authentication using the session store.
OAuth 2.1 authentication using the session store with security validation.
"""
from auth.oauth21_session_store import get_oauth21_session_store
from googleapiclient.discovery import build
store = get_oauth21_session_store()
credentials = store.get_credentials(user_google_email)
# Use the new validation method to ensure session can only access its own credentials
credentials = store.get_credentials_with_validation(
requested_user_email=user_google_email,
session_id=session_id,
auth_token_email=auth_token_email,
allow_recent_auth=allow_recent_auth
)
if not credentials:
from auth.google_auth import GoogleAuthenticationError
raise GoogleAuthenticationError(f"No OAuth 2.1 credentials found for {user_google_email}")
raise GoogleAuthenticationError(
f"Access denied: Cannot retrieve credentials for {user_google_email}. "
f"You can only access credentials for your authenticated account."
)
# Check scopes
if not all(scope in credentials.scopes for scope in required_scopes):
@@ -294,8 +372,9 @@ def require_google_service(
# Check if we have OAuth 2.1 credentials for this user
session_ctx = None
auth_token_email = None
# Try to get FastMCP session ID (for future use)
# Try to get FastMCP session ID and auth info
mcp_session_id = None
session_ctx = None
try:
@@ -309,12 +388,18 @@ def require_google_service(
from core.context import set_fastmcp_session_id
set_fastmcp_session_id(mcp_session_id)
# Extract authenticated email from auth context if available
if hasattr(fastmcp_ctx, 'auth') and fastmcp_ctx.auth:
if hasattr(fastmcp_ctx.auth, 'claims') and fastmcp_ctx.auth.claims:
auth_token_email = fastmcp_ctx.auth.claims.get('email')
logger.debug(f"[{tool_name}] Got authenticated email from token: {auth_token_email}")
# Create session context using FastMCP session ID
from auth.oauth21_session_store import SessionContext
session_ctx = SessionContext(
session_id=mcp_session_id,
user_id=user_google_email,
metadata={"fastmcp_session_id": mcp_session_id, "user_email": user_google_email}
user_id=auth_token_email or user_google_email,
metadata={"fastmcp_session_id": mcp_session_id, "user_email": user_google_email, "auth_email": auth_token_email}
)
except Exception as e:
logger.debug(f"[{tool_name}] Could not get FastMCP context: {e}")
@@ -323,20 +408,151 @@ def require_google_service(
if not session_ctx and OAUTH21_INTEGRATION_AVAILABLE:
session_ctx = get_session_context()
# Also check if user has credentials in OAuth 2.1 store
has_oauth21_creds = False
# Check if the CURRENT REQUEST is authenticated
is_authenticated_request = False
authenticated_user = None
bearer_token = None
if OAUTH21_INTEGRATION_AVAILABLE:
# Check if we have an authenticated FastMCP context
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
has_oauth21_creds = store.has_session(user_google_email)
from fastmcp.server.dependencies import get_context
ctx = get_context()
if ctx and hasattr(ctx, 'auth') and ctx.auth:
# We have authentication info from FastMCP
is_authenticated_request = True
if hasattr(ctx.auth, 'claims'):
authenticated_user = ctx.auth.claims.get('email')
logger.debug(f"[{tool_name}] Authenticated via FastMCP context: {authenticated_user}")
except Exception:
pass
# If FastMCP context didn't provide authentication, check HTTP headers directly
if not is_authenticated_request:
logger.debug(f"[{tool_name}] FastMCP context has no auth, checking HTTP headers for bearer token")
header_user, header_token = await _extract_and_verify_bearer_token()
if header_user and header_token:
is_authenticated_request = True
authenticated_user = header_user
bearer_token = header_token
logger.info(f"[{tool_name}] Authenticated via HTTP bearer token: {authenticated_user}")
# Create session binding for this bearer token authenticated request
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
# Create a session for this bearer token authentication
session_id = f"bearer_{authenticated_user}_{header_token[:8]}"
store.store_session(
user_email=authenticated_user,
access_token=header_token,
session_id=session_id,
mcp_session_id=mcp_session_id
)
logger.debug(f"[{tool_name}] Created session binding for bearer token auth: {session_id}")
except Exception as e:
logger.warning(f"[{tool_name}] Could not create session binding for bearer token: {e}")
else:
logger.debug(f"[{tool_name}] No valid bearer token found in HTTP headers")
# Fallback: Check other authentication indicators
if not is_authenticated_request:
# Check if MCP session is bound to a user
mcp_user = None
if mcp_session_id:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
mcp_user = store.get_user_by_mcp_session(mcp_session_id)
except Exception:
pass
# TEMPORARY: Check if user has recently authenticated (for clients that don't send bearer tokens)
# This still enforces that users can only access their own credentials
has_recent_auth = False
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
has_recent_auth = store.has_session(user_google_email)
if has_recent_auth:
logger.info(f"[{tool_name}] User {user_google_email} has recent auth session (client not sending bearer token)")
except Exception:
pass
is_authenticated_request = (
auth_token_email is not None or
(session_ctx is not None and session_ctx.user_id) or
mcp_user is not None or
has_recent_auth # Allow if user has authenticated (still validates in OAuth21SessionStore)
)
session_id_for_log = mcp_session_id if mcp_session_id else (session_ctx.session_id if session_ctx else 'None')
logger.info(f"[{tool_name}] OAuth 2.1 available: {OAUTH21_INTEGRATION_AVAILABLE}, FastMCP Session ID: {mcp_session_id}, Session context: {session_ctx}, Session ID: {session_id_for_log}, Has OAuth21 creds: {has_oauth21_creds}")
auth_method = "none"
if authenticated_user:
if bearer_token:
auth_method = "bearer_token"
elif auth_token_email:
auth_method = "fastmcp_context"
else:
auth_method = "session"
if OAUTH21_INTEGRATION_AVAILABLE and (session_ctx or has_oauth21_creds):
logger.info(f"[{tool_name}] Authentication Status:"
f" Method={auth_method},"
f" OAuth21={OAUTH21_INTEGRATION_AVAILABLE},"
f" Authenticated={is_authenticated_request},"
f" User={authenticated_user or 'none'},"
f" SessionID={session_id_for_log},"
f" MCPSessionID={mcp_session_id or 'none'}")
# CRITICAL SECURITY: Check if OAuth 2.1 is enabled AND we're in HTTP mode
from core.config import get_transport_mode
transport_mode = get_transport_mode()
# Check if OAuth 2.1 provider is configured (not just transport mode)
oauth21_enabled = False
try:
from core.server import get_auth_provider
auth_provider = get_auth_provider()
oauth21_enabled = auth_provider is not None
except Exception:
pass
if transport_mode == "streamable-http" and oauth21_enabled:
# OAuth 2.1 is enabled - REQUIRE authentication, no fallback to files
if not is_authenticated_request:
logger.error(f"[{tool_name}] SECURITY: Unauthenticated request denied in OAuth 2.1 mode")
raise Exception(
"Authentication required. This server is configured with OAuth 2.1 authentication. "
"Please authenticate first using the OAuth flow before accessing resources."
)
# Additional security: Verify the authenticated user matches the requested user
# Only enforce this if we have a verified authenticated user from a token
if authenticated_user and authenticated_user != user_google_email:
logger.warning(
f"[{tool_name}] User mismatch - token authenticated as {authenticated_user} "
f"but requesting resources for {user_google_email}"
)
# The OAuth21SessionStore will handle the actual validation
# Must use OAuth 2.1 authentication
logger.info(f"[{tool_name}] Using OAuth 2.1 authentication (required for OAuth 2.1 mode)")
# Check if we're allowing recent auth (for clients that don't send bearer tokens)
allow_recent = not authenticated_user and not auth_token_email and not mcp_session_id
service, actual_user_email = await get_authenticated_google_service_oauth21(
service_name=service_name,
version=service_version,
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=resolved_scopes,
session_id=mcp_session_id or (session_ctx.session_id if session_ctx else None),
auth_token_email=auth_token_email or authenticated_user, # Pass authenticated user
allow_recent_auth=allow_recent, # Allow recent auth for clients that don't send tokens
)
elif OAUTH21_INTEGRATION_AVAILABLE and is_authenticated_request:
# In other modes, use OAuth 2.1 if available
logger.info(f"[{tool_name}] Using OAuth 2.1 authentication")
service, actual_user_email = await get_authenticated_google_service_oauth21(
service_name=service_name,
@@ -344,19 +560,25 @@ def require_google_service(
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=resolved_scopes,
session_id=mcp_session_id or (session_ctx.session_id if session_ctx else None),
auth_token_email=auth_token_email,
)
else:
# Fall back to legacy authentication
session_id_for_legacy = mcp_session_id if mcp_session_id else (session_ctx.session_id if session_ctx else None)
logger.info(f"[{tool_name}] Calling get_authenticated_google_service with session_id_for_legacy: {session_id_for_legacy}")
service, actual_user_email = await get_authenticated_google_service(
service_name=service_name,
version=service_version,
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=resolved_scopes,
session_id=session_id_for_legacy,
)
# Fall back to legacy authentication ONLY in stdio mode
if transport_mode == "stdio":
session_id_for_legacy = mcp_session_id if mcp_session_id else (session_ctx.session_id if session_ctx else None)
logger.info(f"[{tool_name}] Using legacy authentication (stdio mode)")
service, actual_user_email = await get_authenticated_google_service(
service_name=service_name,
version=service_version,
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=resolved_scopes,
session_id=session_id_for_legacy,
)
else:
logger.error(f"[{tool_name}] No authentication available in {transport_mode} mode")
raise Exception(f"Authentication not available in {transport_mode} mode")
if cache_enabled:
cache_key = _get_cache_key(user_google_email, service_name, service_version, resolved_scopes)
@@ -436,8 +658,37 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
try:
tool_name = func.__name__
# Try OAuth 2.1 integration first if available
if OAUTH21_INTEGRATION_AVAILABLE and get_session_context():
# Check if OAuth 2.1 is enabled AND we're in HTTP mode
from core.config import get_transport_mode
transport_mode = get_transport_mode()
# Check if OAuth 2.1 provider is configured
oauth21_enabled = False
try:
from core.server import get_auth_provider
auth_provider = get_auth_provider()
oauth21_enabled = auth_provider is not None
except Exception:
pass
# In OAuth 2.1 mode, require authentication
if transport_mode == "streamable-http" and oauth21_enabled:
if not (OAUTH21_INTEGRATION_AVAILABLE and get_session_context()):
logger.error(f"[{tool_name}] SECURITY: Unauthenticated request denied in OAuth 2.1 mode")
raise Exception(
"Authentication required. This server is configured with OAuth 2.1 authentication. "
"Please authenticate first using the OAuth flow before accessing resources."
)
logger.debug(f"OAuth 2.1 authentication for {tool_name} ({service_type})")
service, _ = await get_authenticated_google_service_oauth21(
service_name=service_name,
version=service_version,
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=resolved_scopes,
)
elif OAUTH21_INTEGRATION_AVAILABLE and get_session_context():
logger.debug(f"Attempting OAuth 2.1 authentication for {tool_name} ({service_type})")
service, _ = await get_authenticated_google_service_oauth21(
service_name=service_name,
@@ -447,14 +698,18 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
required_scopes=resolved_scopes,
)
else:
# Fall back to legacy authentication
service, _ = await get_authenticated_google_service(
service_name=service_name,
version=service_version,
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=resolved_scopes,
)
# Fall back to legacy authentication ONLY in stdio mode
if transport_mode == "stdio":
service, _ = await get_authenticated_google_service(
service_name=service_name,
version=service_version,
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=resolved_scopes,
)
else:
logger.error(f"[{tool_name}] No authentication available in {transport_mode} mode")
raise Exception(f"Authentication not available in {transport_mode} mode")
# Inject service with specified parameter name
kwargs[param_name] = service