apply ruff formatting

This commit is contained in:
Taylor Wilsdon
2025-12-13 13:49:28 -08:00
parent 1d80a24ca4
commit 6b8352a354
50 changed files with 4010 additions and 2842 deletions

View File

@@ -34,7 +34,9 @@ def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
try:
return expiry.astimezone(timezone.utc).replace(tzinfo=None)
except Exception: # pragma: no cover - defensive
logger.debug("Failed to normalize aware expiry; returning without tzinfo")
logger.debug(
"Failed to normalize aware expiry; returning without tzinfo"
)
return expiry.replace(tzinfo=None)
return expiry # Already naive; assumed to represent UTC
@@ -51,15 +53,15 @@ def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
# Context variable to store the current session information
_current_session_context: contextvars.ContextVar[Optional['SessionContext']] = contextvars.ContextVar(
'current_session_context',
default=None
_current_session_context: contextvars.ContextVar[Optional["SessionContext"]] = (
contextvars.ContextVar("current_session_context", default=None)
)
@dataclass
class SessionContext:
"""Container for session-related information."""
session_id: Optional[str] = None
user_id: Optional[str] = None
auth_context: Optional[Any] = None
@@ -81,7 +83,9 @@ def set_session_context(context: Optional[SessionContext]):
"""
_current_session_context.set(context)
if context:
logger.debug(f"Set session context: session_id={context.session_id}, user_id={context.user_id}")
logger.debug(
f"Set session context: session_id={context.session_id}, user_id={context.user_id}"
)
else:
logger.debug("Cleared session context")
@@ -161,6 +165,7 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
# If no session found, create a temporary session ID from token hash
# This allows header-based authentication to work with session context
import hashlib
token_hash = hashlib.sha256(token.encode()).hexdigest()[:8]
return f"bearer_token_{token_hash}"
@@ -171,6 +176,7 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
# OAuth21SessionStore - Main Session Management
# =============================================================================
class OAuth21SessionStore:
"""
Global store for OAuth 2.1 authenticated sessions.
@@ -185,8 +191,12 @@ class OAuth21SessionStore:
def __init__(self):
self._sessions: Dict[str, Dict[str, Any]] = {}
self._mcp_session_mapping: Dict[str, str] = {} # Maps FastMCP session ID -> user email
self._session_auth_binding: Dict[str, str] = {} # Maps session ID -> authenticated user email (immutable)
self._mcp_session_mapping: Dict[
str, str
] = {} # Maps FastMCP session ID -> user email
self._session_auth_binding: Dict[
str, str
] = {} # Maps session ID -> authenticated user email (immutable)
self._oauth_states: Dict[str, Dict[str, Any]] = {}
self._lock = RLock()
@@ -258,7 +268,9 @@ class OAuth21SessionStore:
state_info = self._oauth_states.get(state)
if not state_info:
logger.error("SECURITY: OAuth callback received unknown or expired state")
logger.error(
"SECURITY: OAuth callback received unknown or expired state"
)
raise ValueError("Invalid or expired OAuth state parameter")
bound_session = state_info.get("session_id")
@@ -332,16 +344,26 @@ class OAuth21SessionStore:
# Create immutable session binding (first binding wins, cannot be changed)
if mcp_session_id not in self._session_auth_binding:
self._session_auth_binding[mcp_session_id] = user_email
logger.info(f"Created immutable session binding: {mcp_session_id} -> {user_email}")
logger.info(
f"Created immutable session binding: {mcp_session_id} -> {user_email}"
)
elif self._session_auth_binding[mcp_session_id] != user_email:
# Security: Attempt to bind session to different user
logger.error(f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}")
raise ValueError(f"Session {mcp_session_id} is already bound to a different user")
logger.error(
f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}"
)
raise ValueError(
f"Session {mcp_session_id} is already bound to a different user"
)
self._mcp_session_mapping[mcp_session_id] = user_email
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})")
logger.info(
f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})"
)
else:
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})")
logger.info(
f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})"
)
# Also create binding for the OAuth session ID
if session_id and session_id not in self._session_auth_binding:
@@ -382,7 +404,9 @@ class OAuth21SessionStore:
logger.error(f"Failed to create credentials for {user_email}: {e}")
return None
def get_credentials_by_mcp_session(self, mcp_session_id: str) -> Optional[Credentials]:
def get_credentials_by_mcp_session(
self, mcp_session_id: str
) -> Optional[Credentials]:
"""
Get Google credentials using FastMCP session ID.
@@ -407,7 +431,7 @@ class OAuth21SessionStore:
requested_user_email: str,
session_id: Optional[str] = None,
auth_token_email: Optional[str] = None,
allow_recent_auth: bool = False
allow_recent_auth: bool = False,
) -> Optional[Credentials]:
"""
Get Google credentials with session validation.
@@ -466,6 +490,7 @@ class OAuth21SessionStore:
# Check transport mode to ensure this is only used in stdio
try:
from core.config import get_transport_mode
transport_mode = get_transport_mode()
if transport_mode != "stdio":
logger.error(
@@ -533,7 +558,9 @@ class OAuth21SessionStore:
# Also remove from auth binding
if mcp_session_id in self._session_auth_binding:
del self._session_auth_binding[mcp_session_id]
logger.info(f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}")
logger.info(
f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}"
)
# Remove OAuth session binding if exists
if session_id and session_id in self._session_auth_binding:
@@ -612,7 +639,9 @@ def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
try:
client_secret = secret_obj.get_secret_value() # type: ignore[call-arg]
except Exception as exc: # pragma: no cover - defensive
logger.debug(f"Failed to resolve client secret from provider: {exc}")
logger.debug(
f"Failed to resolve client secret from provider: {exc}"
)
elif isinstance(secret_obj, str):
client_secret = secret_obj
@@ -629,7 +658,9 @@ def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
return client_id, client_secret
def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Credentials]:
def _build_credentials_from_provider(
access_token: AccessToken,
) -> Optional[Credentials]:
"""Construct Google credentials from the provider cache."""
if not _auth_provider:
return None
@@ -640,10 +671,14 @@ def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Cred
client_id, client_secret = _resolve_client_credentials()
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(access_token.token)
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(
access_token.token
)
refresh_token_obj = None
if refresh_token_value:
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(refresh_token_value)
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(
refresh_token_value
)
expiry = None
expires_at = getattr(access_entry, "expires_at", None)
@@ -730,7 +765,9 @@ def ensure_session_from_access_token(
return credentials
def get_credentials_from_token(access_token: str, user_email: Optional[str] = None) -> Optional[Credentials]:
def get_credentials_from_token(
access_token: str, user_email: Optional[str] = None
) -> Optional[Credentials]:
"""
Convert a bearer token to Google credentials.
@@ -753,14 +790,18 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
# If the FastMCP provider is managing tokens, sync from provider storage
if _auth_provider:
access_record = getattr(_auth_provider, "_access_tokens", {}).get(access_token)
access_record = getattr(_auth_provider, "_access_tokens", {}).get(
access_token
)
if access_record:
logger.debug("Building credentials from FastMCP provider cache")
return ensure_session_from_access_token(access_record, user_email)
# Otherwise, create minimal credentials with just the access token
# Assume token is valid for 1 hour (typical for Google tokens)
expiry = _normalize_expiry_to_naive_utc(datetime.now(timezone.utc) + timedelta(hours=1))
expiry = _normalize_expiry_to_naive_utc(
datetime.now(timezone.utc) + timedelta(hours=1)
)
client_id, client_secret = _resolve_client_credentials()
credentials = Credentials(
@@ -770,7 +811,7 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
client_id=client_id,
client_secret=client_secret,
scopes=None,
expiry=expiry
expiry=expiry,
)
logger.debug("Created fallback Google credentials from bearer token")
@@ -781,7 +822,9 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
return None
def store_token_session(token_response: dict, user_email: str, mcp_session_id: Optional[str] = None) -> str:
def store_token_session(
token_response: dict, user_email: str, mcp_session_id: Optional[str] = None
) -> str:
"""
Store a token response in the session store.
@@ -802,9 +845,12 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
if not mcp_session_id:
try:
from core.context import get_fastmcp_session_id
mcp_session_id = get_fastmcp_session_id()
if mcp_session_id:
logger.debug(f"Got FastMCP session ID from context: {mcp_session_id}")
logger.debug(
f"Got FastMCP session ID from context: {mcp_session_id}"
)
except Exception as e:
logger.debug(f"Could not get FastMCP session from context: {e}")
@@ -815,7 +861,9 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
client_id, client_secret = _resolve_client_credentials()
scopes = token_response.get("scope", "")
scopes_list = scopes.split() if scopes else None
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600))
expiry = datetime.now(timezone.utc) + timedelta(
seconds=token_response.get("expires_in", 3600)
)
store.store_session(
user_email=user_email,
@@ -832,7 +880,9 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
)
if mcp_session_id:
logger.info(f"Stored token session for {user_email} with MCP session {mcp_session_id}")
logger.info(
f"Stored token session for {user_email} with MCP session {mcp_session_id}"
)
else:
logger.info(f"Stored token session for {user_email}")