oauth2.1 truly works

This commit is contained in:
Taylor Wilsdon
2025-08-02 18:25:08 -04:00
parent 9470a41dde
commit c45bb3956c
7 changed files with 233 additions and 17 deletions

View File

@@ -540,7 +540,7 @@ def get_credentials(
session_id: Optional[str] = None,
) -> Optional[Credentials]:
"""
Retrieves stored credentials, prioritizing session, then file. Refreshes if necessary.
Retrieves stored credentials, prioritizing OAuth 2.1 store, then session, then file. Refreshes if necessary.
If credentials are loaded from file and a session_id is present, they are cached in the session.
In single-user mode, bypasses session mapping and uses any available credentials.
@@ -554,6 +554,52 @@ def get_credentials(
Returns:
Valid Credentials object or None.
"""
# First, try OAuth 2.1 session store if we have a session_id (FastMCP session)
if session_id:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
# Try to get credentials by MCP session
credentials = store.get_credentials_by_mcp_session(session_id)
if credentials:
logger.info(f"[get_credentials] Found OAuth 2.1 credentials for MCP session {session_id}")
# Check scopes
if not all(scope in credentials.scopes for scope in required_scopes):
logger.warning(
f"[get_credentials] OAuth 2.1 credentials lack required scopes. Need: {required_scopes}, Have: {credentials.scopes}"
)
return None
# Return if valid
if credentials.valid:
return credentials
elif credentials.expired and credentials.refresh_token:
# Try to refresh
try:
credentials.refresh(Request())
logger.info(f"[get_credentials] Refreshed OAuth 2.1 credentials for session {session_id}")
# Update stored credentials
user_email = store.get_user_by_mcp_session(session_id)
if user_email:
store.store_session(
user_email=user_email,
access_token=credentials.token,
refresh_token=credentials.refresh_token,
scopes=credentials.scopes,
expiry=credentials.expiry,
mcp_session_id=session_id
)
return credentials
except Exception as e:
logger.error(f"[get_credentials] Failed to refresh OAuth 2.1 credentials: {e}")
return None
except ImportError:
pass # OAuth 2.1 store not available
except Exception as e:
logger.debug(f"[get_credentials] Error checking OAuth 2.1 store: {e}")
# Check for single-user mode
if os.getenv("MCP_SINGLE_USER_MODE") == "1":
logger.info(
@@ -722,6 +768,7 @@ async def get_authenticated_google_service(
tool_name: str, # For logging/debugging
user_google_email: str, # Required - no more Optional
required_scopes: List[str],
session_id: Optional[str] = None, # Session context for logging
) -> tuple[Any, str]:
"""
Centralized Google service authentication for all MCP tools.
@@ -740,8 +787,30 @@ async def get_authenticated_google_service(
Raises:
GoogleAuthenticationError: When authentication is required or fails
"""
# Try to get FastMCP session ID if not provided
if not session_id:
try:
# First try context variable (works in async context)
from core.context import get_fastmcp_session_id
session_id = get_fastmcp_session_id()
if session_id:
logger.debug(f"[{tool_name}] Got FastMCP session ID from context: {session_id}")
except Exception as e:
logger.debug(f"[{tool_name}] Could not get FastMCP session from context: {e}")
# Fallback to direct FastMCP context if context variable not set
if not session_id:
try:
from fastmcp.server.dependencies import get_context
fastmcp_ctx = get_context()
session_id = fastmcp_ctx.session_id
logger.debug(f"[{tool_name}] Got FastMCP session ID directly: {session_id}")
except Exception as e:
logger.debug(f"[{tool_name}] Could not get FastMCP context directly: {e}")
logger.info(
f"[{tool_name}] Attempting to get authenticated {service_name} service. Email: '{user_google_email}'"
f"[{tool_name}] Attempting to get authenticated {service_name} service. Email: '{user_google_email}', Session: '{session_id}'"
)
# Validate email format
@@ -755,7 +824,7 @@ async def get_authenticated_google_service(
user_google_email=user_google_email,
required_scopes=required_scopes,
client_secrets_path=CONFIG_CLIENT_SECRETS_PATH,
session_id=None, # Session ID not available in service layer
session_id=session_id, # Pass through session context
)
if not credentials or not credentials.valid: