refactor oauth2.1 support to fastmcp native
This commit is contained in:
@@ -8,11 +8,12 @@ session context management and credential conversion functionality.
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Dict, Optional, Any, Tuple
|
||||
from threading import RLock
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastmcp.server.auth import AccessToken
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -565,6 +566,131 @@ def get_auth_provider():
|
||||
return _auth_provider
|
||||
|
||||
|
||||
def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Resolve OAuth client credentials from the active provider or configuration."""
|
||||
client_id: Optional[str] = None
|
||||
client_secret: Optional[str] = None
|
||||
|
||||
if _auth_provider:
|
||||
client_id = getattr(_auth_provider, "_upstream_client_id", None)
|
||||
secret_obj = getattr(_auth_provider, "_upstream_client_secret", None)
|
||||
if secret_obj is not None:
|
||||
if hasattr(secret_obj, "get_secret_value"):
|
||||
try:
|
||||
client_secret = secret_obj.get_secret_value() # type: ignore[call-arg]
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(f"Failed to resolve client secret from provider: {exc}")
|
||||
elif isinstance(secret_obj, str):
|
||||
client_secret = secret_obj
|
||||
|
||||
if not client_id or not client_secret:
|
||||
try:
|
||||
from auth.oauth_config import get_oauth_config
|
||||
|
||||
cfg = get_oauth_config()
|
||||
client_id = client_id or cfg.client_id
|
||||
client_secret = client_secret or cfg.client_secret
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(f"Failed to resolve client credentials from config: {exc}")
|
||||
|
||||
return client_id, client_secret
|
||||
|
||||
|
||||
def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Credentials]:
|
||||
"""Construct Google credentials from the provider cache."""
|
||||
if not _auth_provider:
|
||||
return None
|
||||
|
||||
access_entry = getattr(_auth_provider, "_access_tokens", {}).get(access_token.token)
|
||||
if not access_entry:
|
||||
access_entry = access_token
|
||||
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
|
||||
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(access_token.token)
|
||||
refresh_token_obj = None
|
||||
if refresh_token_value:
|
||||
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(refresh_token_value)
|
||||
|
||||
expiry = None
|
||||
expires_at = getattr(access_entry, "expires_at", None)
|
||||
if expires_at:
|
||||
try:
|
||||
expiry = datetime.utcfromtimestamp(expires_at)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
expiry = None
|
||||
|
||||
scopes = getattr(access_entry, "scopes", None)
|
||||
|
||||
return Credentials(
|
||||
token=access_token.token,
|
||||
refresh_token=refresh_token_obj.token if refresh_token_obj else None,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scopes=scopes,
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
|
||||
def ensure_session_from_access_token(
|
||||
access_token: AccessToken,
|
||||
user_email: Optional[str],
|
||||
mcp_session_id: Optional[str] = None,
|
||||
) -> Optional[Credentials]:
|
||||
"""Ensure credentials derived from an access token are cached and returned."""
|
||||
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
email = user_email
|
||||
if not email and getattr(access_token, "claims", None):
|
||||
email = access_token.claims.get("email")
|
||||
|
||||
credentials = _build_credentials_from_provider(access_token)
|
||||
|
||||
if credentials is None:
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
expiry = None
|
||||
expires_at = getattr(access_token, "expires_at", None)
|
||||
if expires_at:
|
||||
try:
|
||||
expiry = datetime.utcfromtimestamp(expires_at)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
expiry = None
|
||||
|
||||
credentials = Credentials(
|
||||
token=access_token.token,
|
||||
refresh_token=None,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scopes=getattr(access_token, "scopes", None),
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
if email:
|
||||
try:
|
||||
store = get_oauth21_session_store()
|
||||
store.store_session(
|
||||
user_email=email,
|
||||
access_token=credentials.token,
|
||||
refresh_token=credentials.refresh_token,
|
||||
token_uri=credentials.token_uri,
|
||||
client_id=credentials.client_id,
|
||||
client_secret=credentials.client_secret,
|
||||
scopes=credentials.scopes,
|
||||
expiry=credentials.expiry,
|
||||
session_id=f"google_{email}",
|
||||
mcp_session_id=mcp_session_id,
|
||||
issuer="https://accounts.google.com",
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(f"Failed to cache credentials for {email}: {exc}")
|
||||
|
||||
return credentials
|
||||
|
||||
|
||||
def get_credentials_from_token(access_token: str, user_email: Optional[str] = None) -> Optional[Credentials]:
|
||||
"""
|
||||
Convert a bearer token to Google credentials.
|
||||
@@ -576,10 +702,6 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
Returns:
|
||||
Google Credentials object or None
|
||||
"""
|
||||
if not _auth_provider:
|
||||
logger.error("Auth provider not configured")
|
||||
return None
|
||||
|
||||
try:
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
@@ -590,21 +712,29 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
logger.debug(f"Found matching credentials from store for {user_email}")
|
||||
return credentials
|
||||
|
||||
# If the FastMCP provider is managing tokens, sync from provider storage
|
||||
if _auth_provider:
|
||||
access_record = getattr(_auth_provider, "_access_tokens", {}).get(access_token)
|
||||
if access_record:
|
||||
logger.debug("Building credentials from FastMCP provider cache")
|
||||
return ensure_session_from_access_token(access_record, user_email)
|
||||
|
||||
# Otherwise, create minimal credentials with just the access token
|
||||
# Assume token is valid for 1 hour (typical for Google tokens)
|
||||
expiry = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
|
||||
credentials = Credentials(
|
||||
token=access_token,
|
||||
refresh_token=None,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=_auth_provider.client_id,
|
||||
client_secret=_auth_provider.client_secret,
|
||||
scopes=None, # Will be populated from token claims if available
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scopes=None,
|
||||
expiry=expiry
|
||||
)
|
||||
|
||||
logger.debug("Created Google credentials from bearer token")
|
||||
logger.debug("Created fallback Google credentials from bearer token")
|
||||
return credentials
|
||||
|
||||
except Exception as e:
|
||||
@@ -643,18 +773,23 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
session_id = f"google_{user_email}"
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
scopes = token_response.get("scope", "")
|
||||
scopes_list = scopes.split() if scopes else None
|
||||
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600))
|
||||
|
||||
store.store_session(
|
||||
user_email=user_email,
|
||||
access_token=token_response.get("access_token"),
|
||||
refresh_token=token_response.get("refresh_token"),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=_auth_provider.client_id,
|
||||
client_secret=_auth_provider.client_secret,
|
||||
scopes=token_response.get("scope", "").split() if token_response.get("scope") else None,
|
||||
expiry=datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600)),
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scopes=scopes_list,
|
||||
expiry=expiry,
|
||||
session_id=session_id,
|
||||
mcp_session_id=mcp_session_id,
|
||||
issuer="https://accounts.google.com", # Add issuer for Google tokens
|
||||
issuer="https://accounts.google.com",
|
||||
)
|
||||
|
||||
if mcp_session_id:
|
||||
|
||||
Reference in New Issue
Block a user