refac bridge and context modules, absorb essential functioanlity into oaauth21sessionstore
This commit is contained in:
@@ -2,18 +2,134 @@
|
||||
OAuth 2.1 Session Store for Google Services
|
||||
|
||||
This module provides a global store for OAuth 2.1 authenticated sessions
|
||||
that can be accessed by Google service decorators.
|
||||
that can be accessed by Google service decorators. It also includes
|
||||
session context management and credential conversion functionality.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
from typing import Dict, Optional, Any
|
||||
from threading import RLock
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Session Context Management (absorbed from session_context.py)
|
||||
# =============================================================================
|
||||
|
||||
# Context variable to store the current session information
|
||||
_current_session_context: contextvars.ContextVar[Optional['SessionContext']] = contextvars.ContextVar(
|
||||
'current_session_context',
|
||||
default=None
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""Container for session-related information."""
|
||||
session_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
auth_context: Optional[Any] = None
|
||||
request: Optional[Any] = None
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
|
||||
|
||||
def set_session_context(context: Optional[SessionContext]):
|
||||
"""
|
||||
Set the current session context.
|
||||
|
||||
Args:
|
||||
context: The session context to set
|
||||
"""
|
||||
_current_session_context.set(context)
|
||||
if context:
|
||||
logger.debug(f"Set session context: session_id={context.session_id}, user_id={context.user_id}")
|
||||
else:
|
||||
logger.debug("Cleared session context")
|
||||
|
||||
|
||||
def get_session_context() -> Optional[SessionContext]:
|
||||
"""
|
||||
Get the current session context.
|
||||
|
||||
Returns:
|
||||
The current session context or None
|
||||
"""
|
||||
return _current_session_context.get()
|
||||
|
||||
|
||||
def clear_session_context():
|
||||
"""Clear the current session context."""
|
||||
set_session_context(None)
|
||||
|
||||
|
||||
class SessionContextManager:
|
||||
"""
|
||||
Context manager for temporarily setting session context.
|
||||
|
||||
Usage:
|
||||
with SessionContextManager(session_context):
|
||||
# Code that needs access to session context
|
||||
pass
|
||||
"""
|
||||
|
||||
def __init__(self, context: Optional[SessionContext]):
|
||||
self.context = context
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
"""Set the session context."""
|
||||
self.token = _current_session_context.set(self.context)
|
||||
return self.context
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Reset the session context."""
|
||||
if self.token:
|
||||
_current_session_context.reset(self.token)
|
||||
|
||||
|
||||
def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Extract session ID from request headers.
|
||||
|
||||
Args:
|
||||
headers: Request headers
|
||||
|
||||
Returns:
|
||||
Session ID if found
|
||||
"""
|
||||
# Try different header names
|
||||
session_id = headers.get("mcp-session-id") or headers.get("Mcp-Session-Id")
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
session_id = headers.get("x-session-id") or headers.get("X-Session-ID")
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
# Try Authorization header for Bearer token
|
||||
auth_header = headers.get("authorization") or headers.get("Authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer "):
|
||||
# For now, we can't extract session from bearer token without the full context
|
||||
# This would need to be handled by the OAuth 2.1 middleware
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OAuth21SessionStore - Main Session Management
|
||||
# =============================================================================
|
||||
|
||||
class OAuth21SessionStore:
|
||||
"""
|
||||
Global store for OAuth 2.1 authenticated sessions.
|
||||
@@ -191,4 +307,127 @@ _global_store = OAuth21SessionStore()
|
||||
|
||||
def get_oauth21_session_store() -> OAuth21SessionStore:
|
||||
"""Get the global OAuth 2.1 session store."""
|
||||
return _global_store
|
||||
return _global_store
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Google Credentials Bridge (absorbed from oauth21_google_bridge.py)
|
||||
# =============================================================================
|
||||
|
||||
# Global auth provider instance (set during server initialization)
|
||||
_auth_provider = None
|
||||
|
||||
|
||||
def set_auth_provider(provider):
|
||||
"""Set the global auth provider instance."""
|
||||
global _auth_provider
|
||||
_auth_provider = provider
|
||||
logger.info("OAuth 2.1 auth provider configured for Google credential bridging")
|
||||
|
||||
|
||||
def get_auth_provider():
|
||||
"""Get the global auth provider instance."""
|
||||
return _auth_provider
|
||||
|
||||
|
||||
def get_credentials_from_token(access_token: str, user_email: Optional[str] = None) -> Optional[Credentials]:
|
||||
"""
|
||||
Convert a bearer token to Google credentials.
|
||||
|
||||
Args:
|
||||
access_token: The bearer token
|
||||
user_email: Optional user email for session lookup
|
||||
|
||||
Returns:
|
||||
Google Credentials object or None
|
||||
"""
|
||||
if not _auth_provider:
|
||||
logger.error("Auth provider not configured")
|
||||
return None
|
||||
|
||||
try:
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
# If we have user_email, try to get credentials from store
|
||||
if user_email:
|
||||
credentials = store.get_credentials(user_email)
|
||||
if credentials and credentials.token == access_token:
|
||||
logger.debug(f"Found matching credentials from store for {user_email}")
|
||||
return credentials
|
||||
|
||||
# Otherwise, create minimal credentials with just the access token
|
||||
# Assume token is valid for 1 hour (typical for Google tokens)
|
||||
expiry = datetime.utcnow() + timedelta(hours=1)
|
||||
|
||||
credentials = Credentials(
|
||||
token=access_token,
|
||||
refresh_token=None,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=_auth_provider.client_id,
|
||||
client_secret=_auth_provider.client_secret,
|
||||
scopes=None, # Will be populated from token claims if available
|
||||
expiry=expiry
|
||||
)
|
||||
|
||||
logger.debug("Created Google credentials from bearer token")
|
||||
return credentials
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Google credentials from token: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def store_token_session(token_response: dict, user_email: str, mcp_session_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Store a token response in the session store.
|
||||
|
||||
Args:
|
||||
token_response: OAuth token response from Google
|
||||
user_email: User's email address
|
||||
mcp_session_id: Optional FastMCP session ID to map to this user
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
if not _auth_provider:
|
||||
logger.error("Auth provider not configured")
|
||||
return ""
|
||||
|
||||
try:
|
||||
# Try to get FastMCP session ID from context if not provided
|
||||
if not mcp_session_id:
|
||||
try:
|
||||
from core.context import get_fastmcp_session_id
|
||||
mcp_session_id = get_fastmcp_session_id()
|
||||
if mcp_session_id:
|
||||
logger.debug(f"Got FastMCP session ID from context: {mcp_session_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get FastMCP session from context: {e}")
|
||||
|
||||
# Store session in OAuth21SessionStore
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
session_id = f"google_{user_email}"
|
||||
store.store_session(
|
||||
user_email=user_email,
|
||||
access_token=token_response.get("access_token"),
|
||||
refresh_token=token_response.get("refresh_token"),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=_auth_provider.client_id,
|
||||
client_secret=_auth_provider.client_secret,
|
||||
scopes=token_response.get("scope", "").split() if token_response.get("scope") else None,
|
||||
expiry=datetime.utcnow() + timedelta(seconds=token_response.get("expires_in", 3600)),
|
||||
session_id=session_id,
|
||||
mcp_session_id=mcp_session_id,
|
||||
)
|
||||
|
||||
if mcp_session_id:
|
||||
logger.info(f"Stored token session for {user_email} with MCP session {mcp_session_id}")
|
||||
else:
|
||||
logger.info(f"Stored token session for {user_email}")
|
||||
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store token session: {e}")
|
||||
return ""
|
||||
Reference in New Issue
Block a user