apply ruff formatting
This commit is contained in:
@@ -34,7 +34,9 @@ def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
|
||||
try:
|
||||
return expiry.astimezone(timezone.utc).replace(tzinfo=None)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.debug("Failed to normalize aware expiry; returning without tzinfo")
|
||||
logger.debug(
|
||||
"Failed to normalize aware expiry; returning without tzinfo"
|
||||
)
|
||||
return expiry.replace(tzinfo=None)
|
||||
return expiry # Already naive; assumed to represent UTC
|
||||
|
||||
@@ -51,15 +53,15 @@ def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
|
||||
|
||||
|
||||
# Context variable to store the current session information
|
||||
_current_session_context: contextvars.ContextVar[Optional['SessionContext']] = contextvars.ContextVar(
|
||||
'current_session_context',
|
||||
default=None
|
||||
_current_session_context: contextvars.ContextVar[Optional["SessionContext"]] = (
|
||||
contextvars.ContextVar("current_session_context", default=None)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""Container for session-related information."""
|
||||
|
||||
session_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
auth_context: Optional[Any] = None
|
||||
@@ -81,7 +83,9 @@ def set_session_context(context: Optional[SessionContext]):
|
||||
"""
|
||||
_current_session_context.set(context)
|
||||
if context:
|
||||
logger.debug(f"Set session context: session_id={context.session_id}, user_id={context.user_id}")
|
||||
logger.debug(
|
||||
f"Set session context: session_id={context.session_id}, user_id={context.user_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug("Cleared session context")
|
||||
|
||||
@@ -161,6 +165,7 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
|
||||
# If no session found, create a temporary session ID from token hash
|
||||
# This allows header-based authentication to work with session context
|
||||
import hashlib
|
||||
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()[:8]
|
||||
return f"bearer_token_{token_hash}"
|
||||
|
||||
@@ -171,6 +176,7 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
|
||||
# OAuth21SessionStore - Main Session Management
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class OAuth21SessionStore:
|
||||
"""
|
||||
Global store for OAuth 2.1 authenticated sessions.
|
||||
@@ -185,8 +191,12 @@ class OAuth21SessionStore:
|
||||
|
||||
def __init__(self):
|
||||
self._sessions: Dict[str, Dict[str, Any]] = {}
|
||||
self._mcp_session_mapping: Dict[str, str] = {} # Maps FastMCP session ID -> user email
|
||||
self._session_auth_binding: Dict[str, str] = {} # Maps session ID -> authenticated user email (immutable)
|
||||
self._mcp_session_mapping: Dict[
|
||||
str, str
|
||||
] = {} # Maps FastMCP session ID -> user email
|
||||
self._session_auth_binding: Dict[
|
||||
str, str
|
||||
] = {} # Maps session ID -> authenticated user email (immutable)
|
||||
self._oauth_states: Dict[str, Dict[str, Any]] = {}
|
||||
self._lock = RLock()
|
||||
|
||||
@@ -258,7 +268,9 @@ class OAuth21SessionStore:
|
||||
state_info = self._oauth_states.get(state)
|
||||
|
||||
if not state_info:
|
||||
logger.error("SECURITY: OAuth callback received unknown or expired state")
|
||||
logger.error(
|
||||
"SECURITY: OAuth callback received unknown or expired state"
|
||||
)
|
||||
raise ValueError("Invalid or expired OAuth state parameter")
|
||||
|
||||
bound_session = state_info.get("session_id")
|
||||
@@ -332,16 +344,26 @@ class OAuth21SessionStore:
|
||||
# Create immutable session binding (first binding wins, cannot be changed)
|
||||
if mcp_session_id not in self._session_auth_binding:
|
||||
self._session_auth_binding[mcp_session_id] = user_email
|
||||
logger.info(f"Created immutable session binding: {mcp_session_id} -> {user_email}")
|
||||
logger.info(
|
||||
f"Created immutable session binding: {mcp_session_id} -> {user_email}"
|
||||
)
|
||||
elif self._session_auth_binding[mcp_session_id] != user_email:
|
||||
# Security: Attempt to bind session to different user
|
||||
logger.error(f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}")
|
||||
raise ValueError(f"Session {mcp_session_id} is already bound to a different user")
|
||||
logger.error(
|
||||
f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Session {mcp_session_id} is already bound to a different user"
|
||||
)
|
||||
|
||||
self._mcp_session_mapping[mcp_session_id] = user_email
|
||||
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})")
|
||||
logger.info(
|
||||
f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})")
|
||||
logger.info(
|
||||
f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})"
|
||||
)
|
||||
|
||||
# Also create binding for the OAuth session ID
|
||||
if session_id and session_id not in self._session_auth_binding:
|
||||
@@ -382,7 +404,9 @@ class OAuth21SessionStore:
|
||||
logger.error(f"Failed to create credentials for {user_email}: {e}")
|
||||
return None
|
||||
|
||||
def get_credentials_by_mcp_session(self, mcp_session_id: str) -> Optional[Credentials]:
|
||||
def get_credentials_by_mcp_session(
|
||||
self, mcp_session_id: str
|
||||
) -> Optional[Credentials]:
|
||||
"""
|
||||
Get Google credentials using FastMCP session ID.
|
||||
|
||||
@@ -407,7 +431,7 @@ class OAuth21SessionStore:
|
||||
requested_user_email: str,
|
||||
session_id: Optional[str] = None,
|
||||
auth_token_email: Optional[str] = None,
|
||||
allow_recent_auth: bool = False
|
||||
allow_recent_auth: bool = False,
|
||||
) -> Optional[Credentials]:
|
||||
"""
|
||||
Get Google credentials with session validation.
|
||||
@@ -466,6 +490,7 @@ class OAuth21SessionStore:
|
||||
# Check transport mode to ensure this is only used in stdio
|
||||
try:
|
||||
from core.config import get_transport_mode
|
||||
|
||||
transport_mode = get_transport_mode()
|
||||
if transport_mode != "stdio":
|
||||
logger.error(
|
||||
@@ -533,7 +558,9 @@ class OAuth21SessionStore:
|
||||
# Also remove from auth binding
|
||||
if mcp_session_id in self._session_auth_binding:
|
||||
del self._session_auth_binding[mcp_session_id]
|
||||
logger.info(f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}")
|
||||
logger.info(
|
||||
f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}"
|
||||
)
|
||||
|
||||
# Remove OAuth session binding if exists
|
||||
if session_id and session_id in self._session_auth_binding:
|
||||
@@ -612,7 +639,9 @@ def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
|
||||
try:
|
||||
client_secret = secret_obj.get_secret_value() # type: ignore[call-arg]
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(f"Failed to resolve client secret from provider: {exc}")
|
||||
logger.debug(
|
||||
f"Failed to resolve client secret from provider: {exc}"
|
||||
)
|
||||
elif isinstance(secret_obj, str):
|
||||
client_secret = secret_obj
|
||||
|
||||
@@ -629,7 +658,9 @@ def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
|
||||
return client_id, client_secret
|
||||
|
||||
|
||||
def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Credentials]:
|
||||
def _build_credentials_from_provider(
|
||||
access_token: AccessToken,
|
||||
) -> Optional[Credentials]:
|
||||
"""Construct Google credentials from the provider cache."""
|
||||
if not _auth_provider:
|
||||
return None
|
||||
@@ -640,10 +671,14 @@ def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Cred
|
||||
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
|
||||
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(access_token.token)
|
||||
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(
|
||||
access_token.token
|
||||
)
|
||||
refresh_token_obj = None
|
||||
if refresh_token_value:
|
||||
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(refresh_token_value)
|
||||
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(
|
||||
refresh_token_value
|
||||
)
|
||||
|
||||
expiry = None
|
||||
expires_at = getattr(access_entry, "expires_at", None)
|
||||
@@ -730,7 +765,9 @@ def ensure_session_from_access_token(
|
||||
return credentials
|
||||
|
||||
|
||||
def get_credentials_from_token(access_token: str, user_email: Optional[str] = None) -> Optional[Credentials]:
|
||||
def get_credentials_from_token(
|
||||
access_token: str, user_email: Optional[str] = None
|
||||
) -> Optional[Credentials]:
|
||||
"""
|
||||
Convert a bearer token to Google credentials.
|
||||
|
||||
@@ -753,14 +790,18 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
|
||||
# If the FastMCP provider is managing tokens, sync from provider storage
|
||||
if _auth_provider:
|
||||
access_record = getattr(_auth_provider, "_access_tokens", {}).get(access_token)
|
||||
access_record = getattr(_auth_provider, "_access_tokens", {}).get(
|
||||
access_token
|
||||
)
|
||||
if access_record:
|
||||
logger.debug("Building credentials from FastMCP provider cache")
|
||||
return ensure_session_from_access_token(access_record, user_email)
|
||||
|
||||
# Otherwise, create minimal credentials with just the access token
|
||||
# Assume token is valid for 1 hour (typical for Google tokens)
|
||||
expiry = _normalize_expiry_to_naive_utc(datetime.now(timezone.utc) + timedelta(hours=1))
|
||||
expiry = _normalize_expiry_to_naive_utc(
|
||||
datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
)
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
|
||||
credentials = Credentials(
|
||||
@@ -770,7 +811,7 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scopes=None,
|
||||
expiry=expiry
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
logger.debug("Created fallback Google credentials from bearer token")
|
||||
@@ -781,7 +822,9 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
return None
|
||||
|
||||
|
||||
def store_token_session(token_response: dict, user_email: str, mcp_session_id: Optional[str] = None) -> str:
|
||||
def store_token_session(
|
||||
token_response: dict, user_email: str, mcp_session_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Store a token response in the session store.
|
||||
|
||||
@@ -802,9 +845,12 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
if not mcp_session_id:
|
||||
try:
|
||||
from core.context import get_fastmcp_session_id
|
||||
|
||||
mcp_session_id = get_fastmcp_session_id()
|
||||
if mcp_session_id:
|
||||
logger.debug(f"Got FastMCP session ID from context: {mcp_session_id}")
|
||||
logger.debug(
|
||||
f"Got FastMCP session ID from context: {mcp_session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get FastMCP session from context: {e}")
|
||||
|
||||
@@ -815,7 +861,9 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
scopes = token_response.get("scope", "")
|
||||
scopes_list = scopes.split() if scopes else None
|
||||
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600))
|
||||
expiry = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=token_response.get("expires_in", 3600)
|
||||
)
|
||||
|
||||
store.store_session(
|
||||
user_email=user_email,
|
||||
@@ -832,7 +880,9 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
)
|
||||
|
||||
if mcp_session_id:
|
||||
logger.info(f"Stored token session for {user_email} with MCP session {mcp_session_id}")
|
||||
logger.info(
|
||||
f"Stored token session for {user_email} with MCP session {mcp_session_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Stored token session for {user_email}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user