almost there, working out session persistence

This commit is contained in:
Taylor Wilsdon
2025-08-02 15:40:23 -04:00
parent 06ef1223dd
commit 9470a41dde
6 changed files with 210 additions and 49 deletions

View File

@@ -27,7 +27,7 @@ class SessionContext:
auth_context: Optional[Any] = None
request: Optional[Any] = None
metadata: Dict[str, Any] = None
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
@@ -36,7 +36,7 @@ class SessionContext:
def set_session_context(context: Optional[SessionContext]):
"""
Set the current session context.
Args:
context: The session context to set
"""
@@ -50,10 +50,11 @@ def set_session_context(context: Optional[SessionContext]):
def get_session_context() -> Optional[SessionContext]:
"""
Get the current session context.
Returns:
The current session context or None
"""
print('called get_session_context')
return _current_session_context.get()
@@ -65,22 +66,22 @@ def clear_session_context():
class SessionContextManager:
"""
Context manager for temporarily setting session context.
Usage:
with SessionContextManager(session_context):
# Code that needs access to session context
pass
"""
def __init__(self, context: Optional[SessionContext]):
self.context = context
self.token = None
def __enter__(self):
"""Set the session context."""
self.token = _current_session_context.set(self.context)
return self.context
def __exit__(self, exc_type, exc_val, exc_tb):
"""Reset the session context."""
if self.token:
@@ -90,10 +91,10 @@ class SessionContextManager:
def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
"""
Extract session ID from request headers.
Args:
headers: Request headers
Returns:
Session ID if found
"""
@@ -101,16 +102,16 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
session_id = headers.get("mcp-session-id") or headers.get("Mcp-Session-Id")
if session_id:
return session_id
session_id = headers.get("x-session-id") or headers.get("X-Session-ID")
if session_id:
return session_id
# Try Authorization header for Bearer token
auth_header = headers.get("authorization") or headers.get("Authorization")
if auth_header and auth_header.lower().startswith("bearer "):
# For now, we can't extract session from bearer token without the full context
# This would need to be handled by the OAuth 2.1 middleware
pass
return None