almost there, working out session persistence

This commit is contained in:
Taylor Wilsdon
2025-08-02 15:40:23 -04:00
parent 06ef1223dd
commit 9470a41dde
6 changed files with 210 additions and 49 deletions

View File

@@ -164,6 +164,9 @@ def load_credentials_from_file(
if creds_data.get("expiry"):
try:
expiry = datetime.fromisoformat(creds_data["expiry"])
# Ensure timezone-naive datetime for Google auth library compatibility
if expiry.tzinfo is not None:
expiry = expiry.replace(tzinfo=None)
except (ValueError, TypeError) as e:
logger.warning(
f"Could not parse expiry time for {user_google_email}: {e}"

View File

@@ -44,21 +44,46 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
headers = dict(request.headers)
session_id = extract_session_from_headers(headers)
# Try to get OAuth 2.1 auth context
# Try to get OAuth 2.1 auth context from FastMCP
auth_context = None
user_email = None
# Check for FastMCP auth context
if hasattr(request.state, "auth"):
auth_context = request.state.auth
# Extract user email from auth claims if available
if hasattr(auth_context, 'claims') and auth_context.claims:
user_email = auth_context.claims.get('email')
# Also check Authorization header for bearer tokens
auth_header = headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer ") and not user_email:
try:
import jwt
token = auth_header[7:] # Remove "Bearer " prefix
# Decode without verification to extract email
claims = jwt.decode(token, options={"verify_signature": False})
user_email = claims.get('email')
if user_email:
logger.debug(f"Extracted user email from JWT: {user_email}")
except:
pass
# Build session context
if session_id or auth_context:
if session_id or auth_context or user_email:
# Create session ID from user email if not provided
if not session_id and user_email:
session_id = f"google_{user_email}"
session_context = SessionContext(
session_id=session_id or (auth_context.session_id if auth_context else None),
user_id=auth_context.user_id if auth_context else None,
user_id=user_email or (auth_context.user_id if auth_context else None),
auth_context=auth_context,
request=request,
metadata={
"path": request.url.path,
"method": request.method,
"user_email": user_email,
}
)

View File

@@ -58,7 +58,8 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
# Calculate expiry
expiry = None
if "expires_in" in token_data:
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_data["expires_in"])
# Google auth library expects timezone-naive datetime
expiry = datetime.utcnow() + timedelta(seconds=token_data["expires_in"])
credentials = Credentials(
token=token_data["access_token"],
@@ -76,7 +77,8 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
# Otherwise, create minimal credentials with just the access token
else:
# Assume token is valid for 1 hour (typical for Google tokens)
expiry = datetime.now(timezone.utc) + timedelta(hours=1)
# Google auth library expects timezone-naive datetime
expiry = datetime.utcnow() + timedelta(hours=1)
credentials = Credentials(
token=access_token,
@@ -129,7 +131,7 @@ def store_token_session(token_response: dict, user_email: str) -> str:
client_id=_auth_provider.client_id,
client_secret=_auth_provider.client_secret,
scopes=token_response.get("scope", "").split() if token_response.get("scope") else None,
expiry=datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600)),
expiry=datetime.utcnow() + timedelta(seconds=token_response.get("expires_in", 3600)),
session_id=session_id,
)

View File

@@ -27,7 +27,7 @@ class SessionContext:
auth_context: Optional[Any] = None
request: Optional[Any] = None
metadata: Dict[str, Any] = None
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
@@ -36,7 +36,7 @@ class SessionContext:
def set_session_context(context: Optional[SessionContext]):
"""
Set the current session context.
Args:
context: The session context to set
"""
@@ -50,10 +50,11 @@ def set_session_context(context: Optional[SessionContext]):
def get_session_context() -> Optional[SessionContext]:
"""
Get the current session context.
Returns:
The current session context or None
"""
print('called get_session_context')
return _current_session_context.get()
@@ -65,22 +66,22 @@ def clear_session_context():
class SessionContextManager:
"""
Context manager for temporarily setting session context.
Usage:
with SessionContextManager(session_context):
# Code that needs access to session context
pass
"""
def __init__(self, context: Optional[SessionContext]):
self.context = context
self.token = None
def __enter__(self):
"""Set the session context."""
self.token = _current_session_context.set(self.context)
return self.context
def __exit__(self, exc_type, exc_val, exc_tb):
"""Reset the session context."""
if self.token:
@@ -90,10 +91,10 @@ class SessionContextManager:
def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
"""
Extract session ID from request headers.
Args:
headers: Request headers
Returns:
Session ID if found
"""
@@ -101,16 +102,16 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
session_id = headers.get("mcp-session-id") or headers.get("Mcp-Session-Id")
if session_id:
return session_id
session_id = headers.get("x-session-id") or headers.get("X-Session-ID")
if session_id:
return session_id
# Try Authorization header for Bearer token
auth_header = headers.get("authorization") or headers.get("Authorization")
if auth_header and auth_header.lower().startswith("bearer "):
# For now, we can't extract session from bearer token without the full context
# This would need to be handled by the OAuth 2.1 middleware
pass
return None