fully working in all transport modes and fallbacks!
This commit is contained in:
@@ -20,34 +20,34 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class OAuth21GoogleServiceBuilder:
|
||||
"""Builds Google services using FastMCP OAuth authenticated sessions."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the service builder.
|
||||
"""
|
||||
self._service_cache: Dict[str, Tuple[Any, str]] = {}
|
||||
|
||||
|
||||
def extract_session_from_context(self, context: Optional[Dict[str, Any]] = None) -> Optional[str]:
|
||||
"""
|
||||
Extract session ID from various context sources.
|
||||
|
||||
|
||||
Args:
|
||||
context: Context dictionary that may contain session information
|
||||
|
||||
|
||||
Returns:
|
||||
Session ID if found, None otherwise
|
||||
"""
|
||||
if not context:
|
||||
return None
|
||||
|
||||
|
||||
# Try to extract from OAuth 2.1 auth context
|
||||
if "auth_context" in context and hasattr(context["auth_context"], "session_id"):
|
||||
return context["auth_context"].session_id
|
||||
|
||||
|
||||
# Try direct session_id
|
||||
if "session_id" in context:
|
||||
return context["session_id"]
|
||||
|
||||
|
||||
# Try from request state
|
||||
if "request" in context:
|
||||
request = context["request"]
|
||||
@@ -55,9 +55,9 @@ class OAuth21GoogleServiceBuilder:
|
||||
auth_ctx = request.state.auth
|
||||
if hasattr(auth_ctx, "session_id"):
|
||||
return auth_ctx.session_id
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_authenticated_service_with_session(
|
||||
self,
|
||||
service_name: str,
|
||||
@@ -70,7 +70,7 @@ class OAuth21GoogleServiceBuilder:
|
||||
) -> Tuple[Any, str]:
|
||||
"""
|
||||
Get authenticated Google service using OAuth 2.1 session if available.
|
||||
|
||||
|
||||
Args:
|
||||
service_name: Google service name (e.g., "gmail", "drive")
|
||||
version: API version (e.g., "v1", "v3")
|
||||
@@ -79,59 +79,62 @@ class OAuth21GoogleServiceBuilder:
|
||||
required_scopes: Required OAuth scopes
|
||||
session_id: OAuth 2.1 session ID
|
||||
auth_context: OAuth 2.1 authentication context
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (service instance, actual user email)
|
||||
|
||||
|
||||
Raises:
|
||||
GoogleAuthenticationError: If authentication fails
|
||||
"""
|
||||
cache_key = f"{user_google_email}:{service_name}:{version}:{':'.join(sorted(required_scopes))}"
|
||||
|
||||
|
||||
# Check cache first
|
||||
if cache_key in self._service_cache:
|
||||
logger.debug(f"[{tool_name}] Using cached service for {user_google_email}")
|
||||
return self._service_cache[cache_key]
|
||||
|
||||
|
||||
try:
|
||||
# First check the global OAuth 2.1 session store
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
store = get_oauth21_session_store()
|
||||
credentials = store.get_credentials(user_google_email)
|
||||
|
||||
|
||||
if credentials and credentials.valid:
|
||||
logger.info(f"[{tool_name}] Found OAuth 2.1 credentials in global store for {user_google_email}")
|
||||
|
||||
|
||||
# Build the service
|
||||
service = await asyncio.to_thread(
|
||||
build, service_name, version, credentials=credentials
|
||||
)
|
||||
|
||||
|
||||
# Cache the service
|
||||
self._service_cache[cache_key] = (service, user_google_email)
|
||||
|
||||
|
||||
return service, user_google_email
|
||||
|
||||
# OAuth 2.1 is now handled by FastMCP - removed legacy auth_layer code
|
||||
|
||||
# Fall back to legacy authentication
|
||||
logger.debug(f"[{tool_name}] Falling back to legacy authentication for {user_google_email}")
|
||||
from auth.google_auth import get_authenticated_google_service as legacy_get_service
|
||||
|
||||
return await legacy_get_service(
|
||||
service_name=service_name,
|
||||
version=version,
|
||||
tool_name=tool_name,
|
||||
user_google_email=user_google_email,
|
||||
required_scopes=required_scopes,
|
||||
|
||||
# If OAuth 2.1 is not enabled, fall back to legacy authentication
|
||||
if not is_oauth21_enabled():
|
||||
logger.debug(f"[{tool_name}] OAuth 2.1 is not enabled. Falling back to legacy authentication for {user_google_email}")
|
||||
return await get_legacy_auth_service(
|
||||
service_name=service_name,
|
||||
version=version,
|
||||
tool_name=tool_name,
|
||||
user_google_email=user_google_email,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
|
||||
# If we are here, it means OAuth 2.1 is enabled but credentials are not found
|
||||
logger.error(f"[{tool_name}] OAuth 2.1 is enabled, but no valid credentials found for {user_google_email}")
|
||||
raise GoogleAuthenticationError(
|
||||
f"OAuth 2.1 is enabled, but no valid credentials found for {user_google_email}"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{tool_name}] Authentication failed for {user_google_email}: {e}")
|
||||
raise GoogleAuthenticationError(
|
||||
f"Failed to authenticate for {service_name}: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear the service cache."""
|
||||
self._service_cache.clear()
|
||||
@@ -157,6 +160,46 @@ def set_auth_layer(auth_layer):
|
||||
logger.info("set_auth_layer called - OAuth is now handled by FastMCP")
|
||||
|
||||
|
||||
_oauth21_enabled = False
|
||||
|
||||
def is_oauth21_enabled() -> bool:
|
||||
"""
|
||||
Check if the OAuth 2.1 authentication layer is active.
|
||||
"""
|
||||
global _oauth21_enabled
|
||||
return _oauth21_enabled
|
||||
|
||||
|
||||
def enable_oauth21():
|
||||
"""
|
||||
Enable the OAuth 2.1 authentication layer.
|
||||
"""
|
||||
global _oauth21_enabled
|
||||
_oauth21_enabled = True
|
||||
logger.info("OAuth 2.1 authentication has been enabled.")
|
||||
|
||||
|
||||
async def get_legacy_auth_service(
|
||||
service_name: str,
|
||||
version: str,
|
||||
tool_name: str,
|
||||
user_google_email: str,
|
||||
required_scopes: list[str],
|
||||
) -> Tuple[Any, str]:
|
||||
"""
|
||||
Get authenticated Google service using legacy authentication.
|
||||
"""
|
||||
from auth.google_auth import get_authenticated_google_service as legacy_get_service
|
||||
|
||||
return await legacy_get_service(
|
||||
service_name=service_name,
|
||||
version=version,
|
||||
tool_name=tool_name,
|
||||
user_google_email=user_google_email,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
|
||||
|
||||
async def get_authenticated_google_service_oauth21(
|
||||
service_name: str,
|
||||
version: str,
|
||||
@@ -167,10 +210,10 @@ async def get_authenticated_google_service_oauth21(
|
||||
) -> Tuple[Any, str]:
|
||||
"""
|
||||
Enhanced version of get_authenticated_google_service that supports OAuth 2.1.
|
||||
|
||||
|
||||
This function checks for OAuth 2.1 session context and uses it if available,
|
||||
otherwise falls back to legacy authentication.
|
||||
|
||||
|
||||
Args:
|
||||
service_name: Google service name
|
||||
version: API version
|
||||
@@ -178,20 +221,20 @@ async def get_authenticated_google_service_oauth21(
|
||||
user_google_email: User's Google email
|
||||
required_scopes: Required OAuth scopes
|
||||
context: Optional context containing session information
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (service instance, actual user email)
|
||||
"""
|
||||
builder = get_oauth21_service_builder()
|
||||
|
||||
|
||||
# FastMCP handles context now - extract any session info
|
||||
session_id = None
|
||||
auth_context = None
|
||||
|
||||
|
||||
if context:
|
||||
session_id = builder.extract_session_from_context(context)
|
||||
auth_context = context.get("auth_context")
|
||||
|
||||
|
||||
return await builder.get_authenticated_service_with_session(
|
||||
service_name=service_name,
|
||||
version=version,
|
||||
|
||||
@@ -41,9 +41,9 @@ async def get_authenticated_google_service_oauth21(
|
||||
"""
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
|
||||
# Use the new validation method to ensure session can only access its own credentials
|
||||
credentials = store.get_credentials_with_validation(
|
||||
requested_user_email=user_google_email,
|
||||
@@ -51,23 +51,23 @@ async def get_authenticated_google_service_oauth21(
|
||||
auth_token_email=auth_token_email,
|
||||
allow_recent_auth=allow_recent_auth
|
||||
)
|
||||
|
||||
|
||||
if not credentials:
|
||||
from auth.google_auth import GoogleAuthenticationError
|
||||
raise GoogleAuthenticationError(
|
||||
f"Access denied: Cannot retrieve credentials for {user_google_email}. "
|
||||
f"You can only access credentials for your authenticated account."
|
||||
)
|
||||
|
||||
|
||||
# Check scopes
|
||||
if not all(scope in credentials.scopes for scope in required_scopes):
|
||||
from auth.google_auth import GoogleAuthenticationError
|
||||
raise GoogleAuthenticationError(f"OAuth 2.1 credentials lack required scopes. Need: {required_scopes}, Have: {credentials.scopes}")
|
||||
|
||||
|
||||
# Build service
|
||||
service = build(service_name, version, credentials=credentials)
|
||||
logger.info(f"[{tool_name}] Successfully authenticated {service_name} service using OAuth 2.1 for user: {user_google_email}")
|
||||
|
||||
|
||||
return service, user_google_email
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -129,7 +129,7 @@ SCOPE_GROUPS = {
|
||||
# Tasks scopes
|
||||
"tasks": TASKS_SCOPE,
|
||||
"tasks_read": TASKS_READONLY_SCOPE,
|
||||
|
||||
|
||||
# Custom Search scope
|
||||
"customsearch": CUSTOM_SEARCH_SCOPE,
|
||||
}
|
||||
@@ -307,13 +307,13 @@ def require_google_service(
|
||||
if service is None:
|
||||
try:
|
||||
tool_name = func.__name__
|
||||
|
||||
|
||||
# SIMPLIFIED: Just get the authenticated user from the context
|
||||
# The AuthInfoMiddleware has already done all the authentication checks
|
||||
authenticated_user = None
|
||||
auth_method = None
|
||||
mcp_session_id = None
|
||||
|
||||
|
||||
try:
|
||||
from fastmcp.server.dependencies import get_context
|
||||
ctx = get_context()
|
||||
@@ -321,82 +321,31 @@ def require_google_service(
|
||||
# Get the authenticated user email set by AuthInfoMiddleware
|
||||
authenticated_user = ctx.get_state("authenticated_user_email")
|
||||
auth_method = ctx.get_state("authenticated_via")
|
||||
|
||||
|
||||
# Get session ID for logging
|
||||
if hasattr(ctx, 'session_id'):
|
||||
mcp_session_id = ctx.session_id
|
||||
# Set FastMCP session ID in context variable for propagation
|
||||
from core.context import set_fastmcp_session_id
|
||||
set_fastmcp_session_id(mcp_session_id)
|
||||
|
||||
|
||||
logger.info(f"[{tool_name}] Authentication from middleware: user={authenticated_user}, method={auth_method}")
|
||||
except Exception as e:
|
||||
logger.debug(f"[{tool_name}] Could not get FastMCP context: {e}")
|
||||
|
||||
|
||||
# Log authentication status
|
||||
logger.info(f"[{tool_name}] Authentication Status:"
|
||||
f" Method={auth_method or 'none'},"
|
||||
f" User={authenticated_user or 'none'},"
|
||||
f" MCPSessionID={mcp_session_id or 'none'}")
|
||||
|
||||
# SIMPLIFIED: Check transport mode and authentication state
|
||||
from core.config import get_transport_mode
|
||||
transport_mode = get_transport_mode()
|
||||
|
||||
# Check if OAuth 2.1 provider is configured
|
||||
oauth21_enabled = False
|
||||
try:
|
||||
from core.server import get_auth_provider
|
||||
auth_provider = get_auth_provider()
|
||||
oauth21_enabled = auth_provider is not None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Determine if we should proceed based on authentication state
|
||||
if transport_mode == "streamable-http" and oauth21_enabled:
|
||||
# OAuth 2.1 mode - require authentication
|
||||
if not authenticated_user:
|
||||
logger.error(f"[{tool_name}] SECURITY: Unauthenticated request denied in OAuth 2.1 mode")
|
||||
raise Exception(
|
||||
"Authentication required. This server is configured with OAuth 2.1 authentication. "
|
||||
"Please authenticate first using the OAuth flow before accessing resources."
|
||||
)
|
||||
|
||||
# Additional security: Verify the authenticated user matches the requested user
|
||||
if authenticated_user != user_google_email:
|
||||
logger.warning(
|
||||
f"[{tool_name}] User mismatch - authenticated as {authenticated_user} "
|
||||
f"but requesting resources for {user_google_email}"
|
||||
)
|
||||
# The OAuth21SessionStore will handle the actual validation
|
||||
|
||||
# Determine authentication method to use
|
||||
# Use OAuth 2.1 only if:
|
||||
# 1. OAuth 2.1 mode is active (oauth21_enabled), OR
|
||||
# 2. We have an authenticated user (from bearer token), OR
|
||||
# 3. In stdio mode, the user has a session in the OAuth21 store
|
||||
use_oauth21 = False
|
||||
if oauth21_enabled and (transport_mode == "streamable-http" or authenticated_user):
|
||||
# OAuth 2.1 mode is active or we have bearer token auth
|
||||
use_oauth21 = True
|
||||
elif transport_mode == "stdio" and OAUTH21_INTEGRATION_AVAILABLE:
|
||||
# In stdio mode, check if user has OAuth 2.1 credentials stored
|
||||
try:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
store = get_oauth21_session_store()
|
||||
if store.has_session(user_google_email):
|
||||
use_oauth21 = True
|
||||
logger.debug(f"[{tool_name}] User has OAuth 2.1 session in stdio mode")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if use_oauth21:
|
||||
# Use OAuth 2.1 authentication
|
||||
logger.info(f"[{tool_name}] Using OAuth 2.1 authentication")
|
||||
|
||||
# Determine if we should allow recent auth (ONLY in stdio mode)
|
||||
allow_recent_auth = (transport_mode == "stdio" and not authenticated_user)
|
||||
|
||||
|
||||
from auth.oauth21_integration import is_oauth21_enabled
|
||||
|
||||
if is_oauth21_enabled():
|
||||
logger.debug(f"[{tool_name}] Attempting OAuth 2.1 authentication flow.")
|
||||
# The downstream get_authenticated_google_service_oauth21 will handle
|
||||
# whether the user's token is valid for the requested resource.
|
||||
# This decorator should not block the call here.
|
||||
service, actual_user_email = await get_authenticated_google_service_oauth21(
|
||||
service_name=service_name,
|
||||
version=service_version,
|
||||
@@ -405,11 +354,11 @@ def require_google_service(
|
||||
required_scopes=resolved_scopes,
|
||||
session_id=mcp_session_id,
|
||||
auth_token_email=authenticated_user,
|
||||
allow_recent_auth=allow_recent_auth,
|
||||
allow_recent_auth=False,
|
||||
)
|
||||
elif transport_mode == "stdio":
|
||||
# Fall back to legacy authentication in stdio mode
|
||||
logger.info(f"[{tool_name}] Using legacy authentication (stdio mode)")
|
||||
else:
|
||||
# If OAuth 2.1 is not enabled, always use the legacy authentication method.
|
||||
logger.debug(f"[{tool_name}] Using legacy authentication flow (OAuth 2.1 disabled).")
|
||||
service, actual_user_email = await get_authenticated_google_service(
|
||||
service_name=service_name,
|
||||
version=service_version,
|
||||
@@ -418,10 +367,7 @@ def require_google_service(
|
||||
required_scopes=resolved_scopes,
|
||||
session_id=mcp_session_id,
|
||||
)
|
||||
else:
|
||||
logger.error(f"[{tool_name}] No authentication available in {transport_mode} mode")
|
||||
raise Exception(f"Authentication not available in {transport_mode} mode")
|
||||
|
||||
|
||||
if cache_enabled:
|
||||
cache_key = _get_cache_key(user_google_email, service_name, service_version, resolved_scopes)
|
||||
_cache_service(cache_key, service, actual_user_email)
|
||||
@@ -500,12 +446,12 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
||||
|
||||
try:
|
||||
tool_name = func.__name__
|
||||
|
||||
|
||||
# SIMPLIFIED: Get authentication state from context (set by AuthInfoMiddleware)
|
||||
authenticated_user = None
|
||||
auth_method = None
|
||||
mcp_session_id = None
|
||||
|
||||
|
||||
try:
|
||||
from fastmcp.server.dependencies import get_context
|
||||
ctx = get_context()
|
||||
@@ -516,47 +462,12 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
||||
mcp_session_id = ctx.session_id
|
||||
except Exception as e:
|
||||
logger.debug(f"[{tool_name}] Could not get FastMCP context: {e}")
|
||||
|
||||
# Check transport mode and OAuth 2.1 configuration
|
||||
from core.config import get_transport_mode
|
||||
transport_mode = get_transport_mode()
|
||||
|
||||
oauth21_enabled = False
|
||||
try:
|
||||
from core.server import get_auth_provider
|
||||
auth_provider = get_auth_provider()
|
||||
oauth21_enabled = auth_provider is not None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# In OAuth 2.1 mode, require authentication
|
||||
if transport_mode == "streamable-http" and oauth21_enabled:
|
||||
if not authenticated_user:
|
||||
logger.error(f"[{tool_name}] SECURITY: Unauthenticated request denied in OAuth 2.1 mode")
|
||||
raise Exception(
|
||||
"Authentication required. This server is configured with OAuth 2.1 authentication. "
|
||||
"Please authenticate first using the OAuth flow before accessing resources."
|
||||
)
|
||||
|
||||
# Determine authentication method to use (same logic as single service)
|
||||
use_oauth21 = False
|
||||
if oauth21_enabled and (transport_mode == "streamable-http" or authenticated_user):
|
||||
use_oauth21 = True
|
||||
elif transport_mode == "stdio" and OAUTH21_INTEGRATION_AVAILABLE:
|
||||
try:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
store = get_oauth21_session_store()
|
||||
if store.has_session(user_google_email):
|
||||
use_oauth21 = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if use_oauth21:
|
||||
logger.debug(f"OAuth 2.1 authentication for {tool_name} ({service_type})")
|
||||
|
||||
# Determine if we should allow recent auth (ONLY in stdio mode)
|
||||
allow_recent_auth = (transport_mode == "stdio" and not authenticated_user)
|
||||
|
||||
|
||||
# Use the same logic as single service decorator
|
||||
from auth.oauth21_integration import is_oauth21_enabled
|
||||
|
||||
if is_oauth21_enabled():
|
||||
logger.debug(f"[{tool_name}] Attempting OAuth 2.1 authentication flow for {service_type}.")
|
||||
service, _ = await get_authenticated_google_service_oauth21(
|
||||
service_name=service_name,
|
||||
version=service_version,
|
||||
@@ -565,10 +476,11 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
||||
required_scopes=resolved_scopes,
|
||||
session_id=mcp_session_id,
|
||||
auth_token_email=authenticated_user,
|
||||
allow_recent_auth=allow_recent_auth,
|
||||
allow_recent_auth=False,
|
||||
)
|
||||
elif transport_mode == "stdio":
|
||||
# Fall back to legacy authentication ONLY in stdio mode
|
||||
else:
|
||||
# If OAuth 2.1 is not enabled, always use the legacy authentication method.
|
||||
logger.debug(f"[{tool_name}] Using legacy authentication flow for {service_type} (OAuth 2.1 disabled).")
|
||||
service, _ = await get_authenticated_google_service(
|
||||
service_name=service_name,
|
||||
version=service_version,
|
||||
@@ -577,9 +489,6 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
||||
required_scopes=resolved_scopes,
|
||||
session_id=mcp_session_id,
|
||||
)
|
||||
else:
|
||||
logger.error(f"[{tool_name}] No authentication available in {transport_mode} mode")
|
||||
raise Exception(f"Authentication not available in {transport_mode} mode")
|
||||
|
||||
# Inject service with specified parameter name
|
||||
kwargs[param_name] = service
|
||||
|
||||
Reference in New Issue
Block a user