Enhanced Session Management better guardrails

This commit is contained in:
Taylor Wilsdon
2025-08-03 15:51:04 -04:00
parent ff9b7ecd07
commit 71e2f1ba3e
4 changed files with 493 additions and 71 deletions

View File

@@ -119,9 +119,21 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
# Try Authorization header for Bearer token
auth_header = headers.get("authorization") or headers.get("Authorization")
if auth_header and auth_header.lower().startswith("bearer "):
# For now, we can't extract session from bearer token without the full context
# This would need to be handled by the OAuth 2.1 middleware
pass
# Extract bearer token and try to find associated session
token = auth_header[7:] # Remove "Bearer " prefix
if token:
# Look for a session that has this access token
# This requires scanning sessions, but bearer tokens should be unique
store = get_oauth21_session_store()
for user_email, session_info in store._sessions.items():
if session_info.get("access_token") == token:
return session_info.get("session_id") or f"bearer_{user_email}"
# If no session found, create a temporary session ID from token hash
# This allows header-based authentication to work with session context
import hashlib
token_hash = hashlib.sha256(token.encode()).hexdigest()[:8]
return f"bearer_token_{token_hash}"
return None
@@ -137,11 +149,15 @@ class OAuth21SessionStore:
This store maintains a mapping of user emails to their OAuth 2.1
authenticated credentials, allowing Google services to access them.
It also maintains a mapping from FastMCP session IDs to user emails.
Security: Sessions are bound to specific users and can only access
their own credentials.
"""
def __init__(self):
self._sessions: Dict[str, Dict[str, Any]] = {}
self._mcp_session_mapping: Dict[str, str] = {} # Maps FastMCP session ID -> user email
self._session_auth_binding: Dict[str, str] = {} # Maps session ID -> authenticated user email (immutable)
self._lock = RLock()
def store_session(
@@ -189,10 +205,23 @@ class OAuth21SessionStore:
# Store MCP session mapping if provided
if mcp_session_id:
# Create immutable session binding (first binding wins, cannot be changed)
if mcp_session_id not in self._session_auth_binding:
self._session_auth_binding[mcp_session_id] = user_email
logger.info(f"Created immutable session binding: {mcp_session_id} -> {user_email}")
elif self._session_auth_binding[mcp_session_id] != user_email:
# Security: Attempt to bind session to different user
logger.error(f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}")
raise ValueError(f"Session {mcp_session_id} is already bound to a different user")
self._mcp_session_mapping[mcp_session_id] = user_email
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})")
else:
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})")
# Also create binding for the OAuth session ID
if session_id and session_id not in self._session_auth_binding:
self._session_auth_binding[session_id] = user_email
def get_credentials(self, user_email: str) -> Optional[Credentials]:
"""
@@ -249,6 +278,79 @@ class OAuth21SessionStore:
logger.debug(f"Found user {user_email} for MCP session {mcp_session_id}")
return self.get_credentials(user_email)
def get_credentials_with_validation(
self,
requested_user_email: str,
session_id: Optional[str] = None,
auth_token_email: Optional[str] = None,
allow_recent_auth: bool = False
) -> Optional[Credentials]:
"""
Get Google credentials with session validation.
This method ensures that a session can only access credentials for its
authenticated user, preventing cross-account access.
Args:
requested_user_email: The email of the user whose credentials are requested
session_id: The current session ID (MCP or OAuth session)
auth_token_email: Email from the verified auth token (if available)
Returns:
Google Credentials object if validation passes, None otherwise
"""
with self._lock:
# Priority 1: Check auth token email (most secure, from verified JWT)
if auth_token_email:
if auth_token_email != requested_user_email:
logger.error(
f"SECURITY VIOLATION: Token for {auth_token_email} attempted to access "
f"credentials for {requested_user_email}"
)
return None
# Token email matches, allow access
return self.get_credentials(requested_user_email)
# Priority 2: Check session binding
if session_id:
bound_user = self._session_auth_binding.get(session_id)
if bound_user:
if bound_user != requested_user_email:
logger.error(
f"SECURITY VIOLATION: Session {session_id} (bound to {bound_user}) "
f"attempted to access credentials for {requested_user_email}"
)
return None
# Session binding matches, allow access
return self.get_credentials(requested_user_email)
# Check if this is an MCP session
mcp_user = self._mcp_session_mapping.get(session_id)
if mcp_user:
if mcp_user != requested_user_email:
logger.error(
f"SECURITY VIOLATION: MCP session {session_id} (user {mcp_user}) "
f"attempted to access credentials for {requested_user_email}"
)
return None
# MCP session matches, allow access
return self.get_credentials(requested_user_email)
# Special case: Allow access if user has recently authenticated (for clients that don't send tokens)
# This is a temporary workaround for MCP clients that complete OAuth but don't send bearer tokens
if allow_recent_auth and requested_user_email in self._sessions:
logger.info(
f"Allowing credential access for {requested_user_email} based on recent authentication "
f"(client not sending bearer token)"
)
return self.get_credentials(requested_user_email)
# No session or token info available - deny access for security
logger.warning(
f"Credential access denied for {requested_user_email}: No valid session or token"
)
return None
def get_user_by_mcp_session(self, mcp_session_id: str) -> Optional[str]:
"""
Get user email by FastMCP session ID.
@@ -266,9 +368,10 @@ class OAuth21SessionStore:
"""Remove session for a user."""
with self._lock:
if user_email in self._sessions:
# Get MCP session ID if exists to clean up mapping
# Get session IDs to clean up mappings
session_info = self._sessions.get(user_email, {})
mcp_session_id = session_info.get("mcp_session_id")
session_id = session_info.get("session_id")
# Remove from sessions
del self._sessions[user_email]
@@ -276,8 +379,16 @@ class OAuth21SessionStore:
# Remove from MCP mapping if exists
if mcp_session_id and mcp_session_id in self._mcp_session_mapping:
del self._mcp_session_mapping[mcp_session_id]
# Also remove from auth binding
if mcp_session_id in self._session_auth_binding:
del self._session_auth_binding[mcp_session_id]
logger.info(f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}")
else:
# Remove OAuth session binding if exists
if session_id and session_id in self._session_auth_binding:
del self._session_auth_binding[session_id]
if not mcp_session_id:
logger.info(f"Removed OAuth 2.1 session for {user_email}")
def has_session(self, user_email: str) -> bool: