cleanup redundant session stores

This commit is contained in:
Taylor Wilsdon
2025-08-03 11:12:58 -04:00
parent 065889fede
commit bb55feed83
4 changed files with 195 additions and 85 deletions

View File

@@ -24,10 +24,8 @@ class MCPOAuth21Bridge:
"""
def __init__(self):
# Map MCP transport session ID to OAuth 2.1 session info
self._mcp_to_oauth21_map: Dict[str, Dict[str, Any]] = {}
# Map OAuth 2.1 session ID to MCP transport session ID
self._oauth21_to_mcp_map: Dict[str, str] = {}
# Session mapping now handled by OAuth21SessionStore
pass
def link_sessions(
self,
@@ -37,32 +35,47 @@ class MCPOAuth21Bridge:
auth_context: Optional[Any] = None
):
"""
Link an MCP transport session with an OAuth 2.1 session.
Link an MCP transport session with an OAuth 2.1 session using OAuth21SessionStore.
Args:
mcp_session_id: MCP transport session ID
oauth21_session_id: OAuth 2.1 session ID
user_id: User identifier
user_id: User identifier (user email)
auth_context: OAuth 2.1 authentication context
"""
session_info = {
"oauth21_session_id": oauth21_session_id,
"user_id": user_id,
"auth_context": auth_context,
"linked_at": datetime.utcnow().isoformat(),
}
from auth.oauth21_session_store import get_oauth21_session_store
self._mcp_to_oauth21_map[mcp_session_id] = session_info
self._oauth21_to_mcp_map[oauth21_session_id] = mcp_session_id
logger.info(
f"Linked MCP session {mcp_session_id} with OAuth 2.1 session {oauth21_session_id} "
f"for user {user_id}"
)
if user_id: # user_id should be the user email
store = get_oauth21_session_store()
# Linking is handled by updating the mcp_session_id in the store
# We need to check if the user already has a session and update it
if store.has_session(user_id):
# Get existing session info to preserve it
existing_creds = store.get_credentials(user_id)
if existing_creds:
store.store_session(
user_email=user_id,
access_token=existing_creds.token,
refresh_token=existing_creds.refresh_token,
token_uri=existing_creds.token_uri,
client_id=existing_creds.client_id,
client_secret=existing_creds.client_secret,
scopes=existing_creds.scopes,
expiry=existing_creds.expiry,
session_id=oauth21_session_id,
mcp_session_id=mcp_session_id
)
logger.info(
f"Linked MCP session {mcp_session_id} with OAuth 2.1 session {oauth21_session_id} "
f"for user {user_id}"
)
else:
logger.warning(f"Cannot link sessions without user_id")
def get_oauth21_session(self, mcp_session_id: str) -> Optional[Dict[str, Any]]:
"""
Get OAuth 2.1 session info for an MCP transport session.
Get OAuth 2.1 session info for an MCP transport session from OAuth21SessionStore.
Args:
mcp_session_id: MCP transport session ID
@@ -70,11 +83,28 @@ class MCPOAuth21Bridge:
Returns:
OAuth 2.1 session information if linked
"""
return self._mcp_to_oauth21_map.get(mcp_session_id)
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
user_email = store.get_user_by_mcp_session(mcp_session_id)
if user_email:
credentials = store.get_credentials(user_email)
if credentials:
return {
"oauth21_session_id": f"oauth21_{user_email}",
"user_id": user_email,
"auth_context": {
"access_token": credentials.token,
"refresh_token": credentials.refresh_token,
"scopes": credentials.scopes
},
"linked_at": datetime.utcnow().isoformat(),
}
return None
def get_mcp_session(self, oauth21_session_id: str) -> Optional[str]:
"""
Get MCP transport session ID for an OAuth 2.1 session.
Get MCP transport session ID for an OAuth 2.1 session from OAuth21SessionStore.
Args:
oauth21_session_id: OAuth 2.1 session ID
@@ -82,21 +112,37 @@ class MCPOAuth21Bridge:
Returns:
MCP transport session ID if linked
"""
return self._oauth21_to_mcp_map.get(oauth21_session_id)
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
# Look through all sessions to find one with matching oauth21_session_id
stats = store.get_stats()
for user_email in stats["users"]:
# Try to match based on session_id pattern
if oauth21_session_id == f"oauth21_{user_email}":
# Get all MCP session mappings and find the one for this user
for mcp_session_id in stats["mcp_sessions"]:
if store.get_user_by_mcp_session(mcp_session_id) == user_email:
return mcp_session_id
return None
def unlink_mcp_session(self, mcp_session_id: str):
"""
Remove the link for an MCP transport session.
Remove the link for an MCP transport session using OAuth21SessionStore.
Args:
mcp_session_id: MCP transport session ID
"""
session_info = self._mcp_to_oauth21_map.pop(mcp_session_id, None)
if session_info:
oauth21_session_id = session_info.get("oauth21_session_id")
if oauth21_session_id:
self._oauth21_to_mcp_map.pop(oauth21_session_id, None)
logger.info(f"Unlinked MCP session {mcp_session_id}")
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
user_email = store.get_user_by_mcp_session(mcp_session_id)
if user_email:
# Remove the entire session since MCP bridge is responsible for the link
store.remove_session(user_email)
logger.info(f"Unlinked MCP session {mcp_session_id} for user {user_email}")
else:
logger.warning(f"No linked session found for MCP session {mcp_session_id}")
def set_session_context_for_mcp(self, mcp_session_id: str) -> bool:
"""
@@ -132,11 +178,15 @@ class MCPOAuth21Bridge:
return True
def get_stats(self) -> Dict[str, Any]:
"""Get bridge statistics."""
"""Get bridge statistics from OAuth21SessionStore."""
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
store_stats = store.get_stats()
return {
"linked_sessions": len(self._mcp_to_oauth21_map),
"mcp_sessions": list(self._mcp_to_oauth21_map.keys()),
"oauth21_sessions": list(self._oauth21_to_mcp_map.keys()),
"linked_sessions": store_stats["mcp_session_mappings"],
"mcp_sessions": store_stats["mcp_sessions"],
"oauth21_sessions": [f"oauth21_{user}" for user in store_stats["users"]],
}