almost there, working out session persistence
This commit is contained in:
@@ -164,6 +164,9 @@ def load_credentials_from_file(
|
||||
if creds_data.get("expiry"):
|
||||
try:
|
||||
expiry = datetime.fromisoformat(creds_data["expiry"])
|
||||
# Ensure timezone-naive datetime for Google auth library compatibility
|
||||
if expiry.tzinfo is not None:
|
||||
expiry = expiry.replace(tzinfo=None)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"Could not parse expiry time for {user_google_email}: {e}"
|
||||
|
||||
@@ -44,21 +44,46 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
|
||||
headers = dict(request.headers)
|
||||
session_id = extract_session_from_headers(headers)
|
||||
|
||||
# Try to get OAuth 2.1 auth context
|
||||
# Try to get OAuth 2.1 auth context from FastMCP
|
||||
auth_context = None
|
||||
user_email = None
|
||||
|
||||
# Check for FastMCP auth context
|
||||
if hasattr(request.state, "auth"):
|
||||
auth_context = request.state.auth
|
||||
# Extract user email from auth claims if available
|
||||
if hasattr(auth_context, 'claims') and auth_context.claims:
|
||||
user_email = auth_context.claims.get('email')
|
||||
|
||||
# Also check Authorization header for bearer tokens
|
||||
auth_header = headers.get("authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer ") and not user_email:
|
||||
try:
|
||||
import jwt
|
||||
token = auth_header[7:] # Remove "Bearer " prefix
|
||||
# Decode without verification to extract email
|
||||
claims = jwt.decode(token, options={"verify_signature": False})
|
||||
user_email = claims.get('email')
|
||||
if user_email:
|
||||
logger.debug(f"Extracted user email from JWT: {user_email}")
|
||||
except:
|
||||
pass
|
||||
|
||||
# Build session context
|
||||
if session_id or auth_context:
|
||||
if session_id or auth_context or user_email:
|
||||
# Create session ID from user email if not provided
|
||||
if not session_id and user_email:
|
||||
session_id = f"google_{user_email}"
|
||||
|
||||
session_context = SessionContext(
|
||||
session_id=session_id or (auth_context.session_id if auth_context else None),
|
||||
user_id=auth_context.user_id if auth_context else None,
|
||||
user_id=user_email or (auth_context.user_id if auth_context else None),
|
||||
auth_context=auth_context,
|
||||
request=request,
|
||||
metadata={
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
"user_email": user_email,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -58,7 +58,8 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
# Calculate expiry
|
||||
expiry = None
|
||||
if "expires_in" in token_data:
|
||||
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_data["expires_in"])
|
||||
# Google auth library expects timezone-naive datetime
|
||||
expiry = datetime.utcnow() + timedelta(seconds=token_data["expires_in"])
|
||||
|
||||
credentials = Credentials(
|
||||
token=token_data["access_token"],
|
||||
@@ -76,7 +77,8 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
# Otherwise, create minimal credentials with just the access token
|
||||
else:
|
||||
# Assume token is valid for 1 hour (typical for Google tokens)
|
||||
expiry = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
# Google auth library expects timezone-naive datetime
|
||||
expiry = datetime.utcnow() + timedelta(hours=1)
|
||||
|
||||
credentials = Credentials(
|
||||
token=access_token,
|
||||
@@ -129,7 +131,7 @@ def store_token_session(token_response: dict, user_email: str) -> str:
|
||||
client_id=_auth_provider.client_id,
|
||||
client_secret=_auth_provider.client_secret,
|
||||
scopes=token_response.get("scope", "").split() if token_response.get("scope") else None,
|
||||
expiry=datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600)),
|
||||
expiry=datetime.utcnow() + timedelta(seconds=token_response.get("expires_in", 3600)),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ class SessionContext:
|
||||
auth_context: Optional[Any] = None
|
||||
request: Optional[Any] = None
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
@@ -36,7 +36,7 @@ class SessionContext:
|
||||
def set_session_context(context: Optional[SessionContext]):
|
||||
"""
|
||||
Set the current session context.
|
||||
|
||||
|
||||
Args:
|
||||
context: The session context to set
|
||||
"""
|
||||
@@ -50,10 +50,11 @@ def set_session_context(context: Optional[SessionContext]):
|
||||
def get_session_context() -> Optional[SessionContext]:
|
||||
"""
|
||||
Get the current session context.
|
||||
|
||||
|
||||
Returns:
|
||||
The current session context or None
|
||||
"""
|
||||
print('called get_session_context')
|
||||
return _current_session_context.get()
|
||||
|
||||
|
||||
@@ -65,22 +66,22 @@ def clear_session_context():
|
||||
class SessionContextManager:
|
||||
"""
|
||||
Context manager for temporarily setting session context.
|
||||
|
||||
|
||||
Usage:
|
||||
with SessionContextManager(session_context):
|
||||
# Code that needs access to session context
|
||||
pass
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, context: Optional[SessionContext]):
|
||||
self.context = context
|
||||
self.token = None
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
"""Set the session context."""
|
||||
self.token = _current_session_context.set(self.context)
|
||||
return self.context
|
||||
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Reset the session context."""
|
||||
if self.token:
|
||||
@@ -90,10 +91,10 @@ class SessionContextManager:
|
||||
def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Extract session ID from request headers.
|
||||
|
||||
|
||||
Args:
|
||||
headers: Request headers
|
||||
|
||||
|
||||
Returns:
|
||||
Session ID if found
|
||||
"""
|
||||
@@ -101,16 +102,16 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
|
||||
session_id = headers.get("mcp-session-id") or headers.get("Mcp-Session-Id")
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
|
||||
session_id = headers.get("x-session-id") or headers.get("X-Session-ID")
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
|
||||
# Try Authorization header for Bearer token
|
||||
auth_header = headers.get("authorization") or headers.get("Authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer "):
|
||||
# For now, we can't extract session from bearer token without the full context
|
||||
# This would need to be handled by the OAuth 2.1 middleware
|
||||
pass
|
||||
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user