implement WORKSPACE_MCP_STATELESS_MODE
This commit is contained in:
@@ -17,11 +17,6 @@ from google.oauth2.credentials import Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Session Context Management (absorbed from session_context.py)
|
||||
# =============================================================================
|
||||
|
||||
# Context variable to store the current session information
|
||||
_current_session_context: contextvars.ContextVar[Optional['SessionContext']] = contextvars.ContextVar(
|
||||
'current_session_context',
|
||||
@@ -129,7 +124,7 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
|
||||
for user_email, session_info in store._sessions.items():
|
||||
if session_info.get("access_token") == token:
|
||||
return session_info.get("session_id") or f"bearer_{user_email}"
|
||||
|
||||
|
||||
# If no session found, create a temporary session ID from token hash
|
||||
# This allows header-based authentication to work with session context
|
||||
import hashlib
|
||||
@@ -146,21 +141,21 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
|
||||
class OAuth21SessionStore:
|
||||
"""
|
||||
Global store for OAuth 2.1 authenticated sessions.
|
||||
|
||||
|
||||
This store maintains a mapping of user emails to their OAuth 2.1
|
||||
authenticated credentials, allowing Google services to access them.
|
||||
It also maintains a mapping from FastMCP session IDs to user emails.
|
||||
|
||||
|
||||
Security: Sessions are bound to specific users and can only access
|
||||
their own credentials.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._sessions: Dict[str, Dict[str, Any]] = {}
|
||||
self._mcp_session_mapping: Dict[str, str] = {} # Maps FastMCP session ID -> user email
|
||||
self._session_auth_binding: Dict[str, str] = {} # Maps session ID -> authenticated user email (immutable)
|
||||
self._lock = RLock()
|
||||
|
||||
|
||||
def store_session(
|
||||
self,
|
||||
user_email: str,
|
||||
@@ -177,7 +172,7 @@ class OAuth21SessionStore:
|
||||
):
|
||||
"""
|
||||
Store OAuth 2.1 session information.
|
||||
|
||||
|
||||
Args:
|
||||
user_email: User's email address
|
||||
access_token: OAuth 2.1 access token
|
||||
@@ -204,9 +199,9 @@ class OAuth21SessionStore:
|
||||
"mcp_session_id": mcp_session_id,
|
||||
"issuer": issuer,
|
||||
}
|
||||
|
||||
|
||||
self._sessions[user_email] = session_info
|
||||
|
||||
|
||||
# Store MCP session mapping if provided
|
||||
if mcp_session_id:
|
||||
# Create immutable session binding (first binding wins, cannot be changed)
|
||||
@@ -217,23 +212,23 @@ class OAuth21SessionStore:
|
||||
# Security: Attempt to bind session to different user
|
||||
logger.error(f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}")
|
||||
raise ValueError(f"Session {mcp_session_id} is already bound to a different user")
|
||||
|
||||
|
||||
self._mcp_session_mapping[mcp_session_id] = user_email
|
||||
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})")
|
||||
else:
|
||||
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})")
|
||||
|
||||
|
||||
# Also create binding for the OAuth session ID
|
||||
if session_id and session_id not in self._session_auth_binding:
|
||||
self._session_auth_binding[session_id] = user_email
|
||||
|
||||
|
||||
def get_credentials(self, user_email: str) -> Optional[Credentials]:
|
||||
"""
|
||||
Get Google credentials for a user from OAuth 2.1 session.
|
||||
|
||||
|
||||
Args:
|
||||
user_email: User's email address
|
||||
|
||||
|
||||
Returns:
|
||||
Google Credentials object or None
|
||||
"""
|
||||
@@ -242,7 +237,7 @@ class OAuth21SessionStore:
|
||||
if not session_info:
|
||||
logger.debug(f"No OAuth 2.1 session found for {user_email}")
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
# Create Google credentials from session info
|
||||
credentials = Credentials(
|
||||
@@ -254,21 +249,21 @@ class OAuth21SessionStore:
|
||||
scopes=session_info.get("scopes", []),
|
||||
expiry=session_info.get("expiry"),
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"Retrieved OAuth 2.1 credentials for {user_email}")
|
||||
return credentials
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create credentials for {user_email}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_credentials_by_mcp_session(self, mcp_session_id: str) -> Optional[Credentials]:
|
||||
"""
|
||||
Get Google credentials using FastMCP session ID.
|
||||
|
||||
|
||||
Args:
|
||||
mcp_session_id: FastMCP session ID
|
||||
|
||||
|
||||
Returns:
|
||||
Google Credentials object or None
|
||||
"""
|
||||
@@ -278,28 +273,28 @@ class OAuth21SessionStore:
|
||||
if not user_email:
|
||||
logger.debug(f"No user mapping found for MCP session {mcp_session_id}")
|
||||
return None
|
||||
|
||||
|
||||
logger.debug(f"Found user {user_email} for MCP session {mcp_session_id}")
|
||||
return self.get_credentials(user_email)
|
||||
|
||||
|
||||
def get_credentials_with_validation(
|
||||
self,
|
||||
requested_user_email: str,
|
||||
self,
|
||||
requested_user_email: str,
|
||||
session_id: Optional[str] = None,
|
||||
auth_token_email: Optional[str] = None,
|
||||
allow_recent_auth: bool = False
|
||||
) -> Optional[Credentials]:
|
||||
"""
|
||||
Get Google credentials with session validation.
|
||||
|
||||
|
||||
This method ensures that a session can only access credentials for its
|
||||
authenticated user, preventing cross-account access.
|
||||
|
||||
|
||||
Args:
|
||||
requested_user_email: The email of the user whose credentials are requested
|
||||
session_id: The current session ID (MCP or OAuth session)
|
||||
auth_token_email: Email from the verified auth token (if available)
|
||||
|
||||
|
||||
Returns:
|
||||
Google Credentials object if validation passes, None otherwise
|
||||
"""
|
||||
@@ -314,7 +309,7 @@ class OAuth21SessionStore:
|
||||
return None
|
||||
# Token email matches, allow access
|
||||
return self.get_credentials(requested_user_email)
|
||||
|
||||
|
||||
# Priority 2: Check session binding
|
||||
if session_id:
|
||||
bound_user = self._session_auth_binding.get(session_id)
|
||||
@@ -327,7 +322,7 @@ class OAuth21SessionStore:
|
||||
return None
|
||||
# Session binding matches, allow access
|
||||
return self.get_credentials(requested_user_email)
|
||||
|
||||
|
||||
# Check if this is an MCP session
|
||||
mcp_user = self._mcp_session_mapping.get(session_id)
|
||||
if mcp_user:
|
||||
@@ -339,7 +334,7 @@ class OAuth21SessionStore:
|
||||
return None
|
||||
# MCP session matches, allow access
|
||||
return self.get_credentials(requested_user_email)
|
||||
|
||||
|
||||
# Special case: Allow access if user has recently authenticated (for clients that don't send tokens)
|
||||
# CRITICAL SECURITY: This is ONLY allowed in stdio mode, NEVER in OAuth 2.1 mode
|
||||
if allow_recent_auth and requested_user_email in self._sessions:
|
||||
@@ -356,45 +351,45 @@ class OAuth21SessionStore:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check transport mode: {e}")
|
||||
return None
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Allowing credential access for {requested_user_email} based on recent authentication "
|
||||
f"(stdio mode only - client not sending bearer token)"
|
||||
)
|
||||
return self.get_credentials(requested_user_email)
|
||||
|
||||
|
||||
# No session or token info available - deny access for security
|
||||
logger.warning(
|
||||
f"Credential access denied for {requested_user_email}: No valid session or token"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_user_by_mcp_session(self, mcp_session_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get user email by FastMCP session ID.
|
||||
|
||||
|
||||
Args:
|
||||
mcp_session_id: FastMCP session ID
|
||||
|
||||
|
||||
Returns:
|
||||
User email or None
|
||||
"""
|
||||
with self._lock:
|
||||
return self._mcp_session_mapping.get(mcp_session_id)
|
||||
|
||||
|
||||
def get_session_info(self, user_email: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get complete session information including issuer.
|
||||
|
||||
|
||||
Args:
|
||||
user_email: User's email address
|
||||
|
||||
|
||||
Returns:
|
||||
Session information dictionary or None
|
||||
"""
|
||||
with self._lock:
|
||||
return self._sessions.get(user_email)
|
||||
|
||||
|
||||
def remove_session(self, user_email: str):
|
||||
"""Remove session for a user."""
|
||||
with self._lock:
|
||||
@@ -403,10 +398,10 @@ class OAuth21SessionStore:
|
||||
session_info = self._sessions.get(user_email, {})
|
||||
mcp_session_id = session_info.get("mcp_session_id")
|
||||
session_id = session_info.get("session_id")
|
||||
|
||||
|
||||
# Remove from sessions
|
||||
del self._sessions[user_email]
|
||||
|
||||
|
||||
# Remove from MCP mapping if exists
|
||||
if mcp_session_id and mcp_session_id in self._mcp_session_mapping:
|
||||
del self._mcp_session_mapping[mcp_session_id]
|
||||
@@ -414,24 +409,24 @@ class OAuth21SessionStore:
|
||||
if mcp_session_id in self._session_auth_binding:
|
||||
del self._session_auth_binding[mcp_session_id]
|
||||
logger.info(f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}")
|
||||
|
||||
|
||||
# Remove OAuth session binding if exists
|
||||
if session_id and session_id in self._session_auth_binding:
|
||||
del self._session_auth_binding[session_id]
|
||||
|
||||
|
||||
if not mcp_session_id:
|
||||
logger.info(f"Removed OAuth 2.1 session for {user_email}")
|
||||
|
||||
|
||||
def has_session(self, user_email: str) -> bool:
|
||||
"""Check if a user has an active session."""
|
||||
with self._lock:
|
||||
return user_email in self._sessions
|
||||
|
||||
|
||||
def has_mcp_session(self, mcp_session_id: str) -> bool:
|
||||
"""Check if an MCP session has an associated user session."""
|
||||
with self._lock:
|
||||
return mcp_session_id in self._mcp_session_mapping
|
||||
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get store statistics."""
|
||||
with self._lock:
|
||||
@@ -475,32 +470,32 @@ def get_auth_provider():
|
||||
def get_credentials_from_token(access_token: str, user_email: Optional[str] = None) -> Optional[Credentials]:
|
||||
"""
|
||||
Convert a bearer token to Google credentials.
|
||||
|
||||
|
||||
Args:
|
||||
access_token: The bearer token
|
||||
user_email: Optional user email for session lookup
|
||||
|
||||
|
||||
Returns:
|
||||
Google Credentials object or None
|
||||
"""
|
||||
if not _auth_provider:
|
||||
logger.error("Auth provider not configured")
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
|
||||
# If we have user_email, try to get credentials from store
|
||||
if user_email:
|
||||
credentials = store.get_credentials(user_email)
|
||||
if credentials and credentials.token == access_token:
|
||||
logger.debug(f"Found matching credentials from store for {user_email}")
|
||||
return credentials
|
||||
|
||||
|
||||
# Otherwise, create minimal credentials with just the access token
|
||||
# Assume token is valid for 1 hour (typical for Google tokens)
|
||||
expiry = datetime.utcnow() + timedelta(hours=1)
|
||||
|
||||
|
||||
credentials = Credentials(
|
||||
token=access_token,
|
||||
refresh_token=None,
|
||||
@@ -510,10 +505,10 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
scopes=None, # Will be populated from token claims if available
|
||||
expiry=expiry
|
||||
)
|
||||
|
||||
|
||||
logger.debug("Created Google credentials from bearer token")
|
||||
return credentials
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Google credentials from token: {e}")
|
||||
return None
|
||||
@@ -522,19 +517,19 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
def store_token_session(token_response: dict, user_email: str, mcp_session_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Store a token response in the session store.
|
||||
|
||||
|
||||
Args:
|
||||
token_response: OAuth token response from Google
|
||||
user_email: User's email address
|
||||
mcp_session_id: Optional FastMCP session ID to map to this user
|
||||
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
if not _auth_provider:
|
||||
logger.error("Auth provider not configured")
|
||||
return ""
|
||||
|
||||
|
||||
try:
|
||||
# Try to get FastMCP session ID from context if not provided
|
||||
if not mcp_session_id:
|
||||
@@ -545,10 +540,10 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
logger.debug(f"Got FastMCP session ID from context: {mcp_session_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get FastMCP session from context: {e}")
|
||||
|
||||
|
||||
# Store session in OAuth21SessionStore
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
|
||||
session_id = f"google_{user_email}"
|
||||
store.store_session(
|
||||
user_email=user_email,
|
||||
@@ -563,14 +558,14 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
mcp_session_id=mcp_session_id,
|
||||
issuer="https://accounts.google.com", # Add issuer for Google tokens
|
||||
)
|
||||
|
||||
|
||||
if mcp_session_id:
|
||||
logger.info(f"Stored token session for {user_email} with MCP session {mcp_session_id}")
|
||||
else:
|
||||
logger.info(f"Stored token session for {user_email}")
|
||||
|
||||
|
||||
return session_id
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store token session: {e}")
|
||||
return ""
|
||||
Reference in New Issue
Block a user