refactor to centralize, move to desktop type
This commit is contained in:
@@ -160,23 +160,21 @@ def set_auth_layer(auth_layer):
|
||||
logger.info("set_auth_layer called - OAuth is now handled by FastMCP")
|
||||
|
||||
|
||||
_oauth21_enabled = False
|
||||
|
||||
def is_oauth21_enabled() -> bool:
|
||||
"""
|
||||
Check if the OAuth 2.1 authentication layer is active.
|
||||
Uses centralized configuration from oauth_config.
|
||||
"""
|
||||
global _oauth21_enabled
|
||||
return _oauth21_enabled
|
||||
from auth.oauth_config import is_oauth21_enabled as config_oauth21_enabled
|
||||
return config_oauth21_enabled()
|
||||
|
||||
|
||||
def enable_oauth21():
|
||||
"""
|
||||
Enable the OAuth 2.1 authentication layer.
|
||||
Note: This is now controlled by MCP_ENABLE_OAUTH21 env var via oauth_config.
|
||||
"""
|
||||
global _oauth21_enabled
|
||||
_oauth21_enabled = True
|
||||
logger.debug("OAuth 2.1 authentication enabled")
|
||||
logger.debug("OAuth 2.1 authentication enable request - controlled by MCP_ENABLE_OAUTH21 env var")
|
||||
|
||||
|
||||
async def get_legacy_auth_service(
|
||||
@@ -206,13 +204,16 @@ async def get_authenticated_google_service_oauth21(
|
||||
tool_name: str,
|
||||
user_google_email: str,
|
||||
required_scopes: list[str],
|
||||
session_id: Optional[str] = None,
|
||||
auth_token_email: Optional[str] = None,
|
||||
allow_recent_auth: bool = False,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, str]:
|
||||
"""
|
||||
Enhanced version of get_authenticated_google_service that supports OAuth 2.1.
|
||||
|
||||
This function checks for OAuth 2.1 session context and uses it if available,
|
||||
otherwise falls back to legacy authentication.
|
||||
otherwise falls back to legacy authentication based on configuration.
|
||||
|
||||
Args:
|
||||
service_name: Google service name
|
||||
@@ -220,20 +221,32 @@ async def get_authenticated_google_service_oauth21(
|
||||
tool_name: Tool name for logging
|
||||
user_google_email: User's Google email
|
||||
required_scopes: Required OAuth scopes
|
||||
session_id: Optional OAuth session ID
|
||||
auth_token_email: Optional authenticated user email from token
|
||||
allow_recent_auth: Whether to allow recently authenticated sessions
|
||||
context: Optional context containing session information
|
||||
|
||||
Returns:
|
||||
Tuple of (service instance, actual user email)
|
||||
"""
|
||||
# Check if OAuth 2.1 is truly enabled
|
||||
if not is_oauth21_enabled():
|
||||
logger.debug(f"[{tool_name}] OAuth 2.1 disabled, using legacy authentication")
|
||||
return await get_legacy_auth_service(
|
||||
service_name=service_name,
|
||||
version=version,
|
||||
tool_name=tool_name,
|
||||
user_google_email=user_google_email,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
|
||||
builder = get_oauth21_service_builder()
|
||||
|
||||
# FastMCP handles context now - extract any session info
|
||||
session_id = None
|
||||
auth_context = None
|
||||
|
||||
if context:
|
||||
if not session_id and context:
|
||||
session_id = builder.extract_session_from_context(context)
|
||||
auth_context = context.get("auth_context")
|
||||
|
||||
auth_context = context.get("auth_context") if context else None
|
||||
|
||||
return await builder.get_authenticated_service_with_session(
|
||||
service_name=service_name,
|
||||
@@ -243,4 +256,36 @@ async def get_authenticated_google_service_oauth21(
|
||||
required_scopes=required_scopes,
|
||||
session_id=session_id,
|
||||
auth_context=auth_context,
|
||||
)
|
||||
|
||||
|
||||
async def get_authenticated_google_service_oauth21_v2(
|
||||
request: "OAuth21ServiceRequest",
|
||||
) -> Tuple[Any, str]:
|
||||
"""
|
||||
Enhanced version of get_authenticated_google_service that supports OAuth 2.1.
|
||||
|
||||
This version uses a parameter object to reduce function complexity and
|
||||
improve maintainability. It's the recommended approach for new code.
|
||||
|
||||
Args:
|
||||
request: OAuth21ServiceRequest object containing all parameters
|
||||
|
||||
Returns:
|
||||
Tuple of (service instance, actual user email)
|
||||
"""
|
||||
from auth.oauth_types import OAuth21ServiceRequest
|
||||
|
||||
# Delegate to the original function for now
|
||||
# This provides a migration path while maintaining backward compatibility
|
||||
return await get_authenticated_google_service_oauth21(
|
||||
service_name=request.service_name,
|
||||
version=request.version,
|
||||
tool_name=request.tool_name,
|
||||
user_google_email=request.user_google_email,
|
||||
required_scopes=request.required_scopes,
|
||||
session_id=request.session_id,
|
||||
auth_token_email=request.auth_token_email,
|
||||
allow_recent_auth=request.allow_recent_auth,
|
||||
context=request.context,
|
||||
)
|
||||
@@ -19,7 +19,7 @@ from urllib.parse import urlparse
|
||||
from auth.scopes import SCOPES
|
||||
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
|
||||
from auth.google_auth import handle_auth_callback, check_client_secrets
|
||||
from core.config import get_oauth_redirect_uri
|
||||
from auth.oauth_config import get_oauth_redirect_uri
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -286,21 +286,10 @@ async def handle_oauth_authorization_server(request: Request):
|
||||
)
|
||||
|
||||
config = get_oauth_config()
|
||||
base_url = config.get_oauth_base_url()
|
||||
|
||||
# Build authorization server metadata per RFC 8414
|
||||
metadata = {
|
||||
"issuer": base_url,
|
||||
"authorization_endpoint": f"{base_url}/oauth2/authorize",
|
||||
"token_endpoint": f"{base_url}/oauth2/token",
|
||||
"registration_endpoint": f"{base_url}/oauth2/register",
|
||||
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
|
||||
"response_types_supported": ["code", "token"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
||||
"scopes_supported": get_current_scopes(),
|
||||
"code_challenge_methods_supported": ["S256", "plain"],
|
||||
}
|
||||
|
||||
# Get authorization server metadata from centralized config
|
||||
# Pass scopes directly to keep all metadata generation in one place
|
||||
metadata = config.get_authorization_server_metadata(scopes=get_current_scopes())
|
||||
|
||||
logger.debug(f"Returning authorization server metadata: {metadata}")
|
||||
|
||||
@@ -363,7 +352,7 @@ async def handle_oauth_client_config(request: Request):
|
||||
"response_types": ["code"],
|
||||
"scope": " ".join(get_current_scopes()),
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"code_challenge_methods": ["S256"]
|
||||
"code_challenge_methods": config.supported_code_challenge_methods[:1] # Primary method only
|
||||
},
|
||||
headers=response_headers
|
||||
)
|
||||
@@ -411,7 +400,7 @@ async def handle_oauth_register(request: Request):
|
||||
"response_types": body.get("response_types", ["code"]),
|
||||
"scope": body.get("scope", " ".join(get_current_scopes())),
|
||||
"token_endpoint_auth_method": body.get("token_endpoint_auth_method", "client_secret_basic"),
|
||||
"code_challenge_methods": ["S256"],
|
||||
"code_challenge_methods": config.supported_code_challenge_methods,
|
||||
# Additional OAuth 2.1 fields
|
||||
"client_id_issued_at": int(time.time()),
|
||||
"registration_access_token": "not-required", # We don't implement client management
|
||||
|
||||
@@ -4,10 +4,12 @@ OAuth Configuration Management
|
||||
This module centralizes OAuth-related configuration to eliminate hardcoded values
|
||||
scattered throughout the codebase. It provides environment variable support and
|
||||
sensible defaults for all OAuth-related settings.
|
||||
|
||||
Supports both OAuth 2.0 and OAuth 2.1 with automatic client capability detection.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
|
||||
class OAuthConfig:
|
||||
@@ -29,6 +31,14 @@ class OAuthConfig:
|
||||
self.client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
self.client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||
|
||||
# OAuth 2.1 configuration
|
||||
self.oauth21_enabled = os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true"
|
||||
self.pkce_required = self.oauth21_enabled # PKCE is mandatory in OAuth 2.1
|
||||
self.supported_code_challenge_methods = ["S256", "plain"] if not self.oauth21_enabled else ["S256"]
|
||||
|
||||
# Transport mode (will be set at runtime)
|
||||
self._transport_mode = "stdio" # Default
|
||||
|
||||
# Redirect URI configuration
|
||||
self.redirect_uri = self._get_redirect_uri()
|
||||
|
||||
@@ -187,12 +197,112 @@ class OAuthConfig:
|
||||
"base_url": self.base_url,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"client_configured": bool(self.client_id),
|
||||
"oauth21_enabled": self.oauth21_enabled,
|
||||
"pkce_required": self.pkce_required,
|
||||
"transport_mode": self._transport_mode,
|
||||
"vscode_callback_port": self.vscode_callback_port,
|
||||
"vscode_callback_hosts": self.vscode_callback_hosts,
|
||||
"development_ports": self.development_ports,
|
||||
"total_redirect_uris": len(self.get_redirect_uris()),
|
||||
"total_allowed_origins": len(self.get_allowed_origins()),
|
||||
}
|
||||
|
||||
def set_transport_mode(self, mode: str) -> None:
|
||||
"""
|
||||
Set the current transport mode for OAuth callback handling.
|
||||
|
||||
Args:
|
||||
mode: Transport mode ("stdio", "streamable-http", etc.)
|
||||
"""
|
||||
self._transport_mode = mode
|
||||
|
||||
def get_transport_mode(self) -> str:
|
||||
"""
|
||||
Get the current transport mode.
|
||||
|
||||
Returns:
|
||||
Current transport mode
|
||||
"""
|
||||
return self._transport_mode
|
||||
|
||||
def is_oauth21_enabled(self) -> bool:
|
||||
"""
|
||||
Check if OAuth 2.1 mode is enabled.
|
||||
|
||||
Returns:
|
||||
True if OAuth 2.1 is enabled
|
||||
"""
|
||||
return self.oauth21_enabled
|
||||
|
||||
def detect_oauth_version(self, request_params: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Detect OAuth version based on request parameters.
|
||||
|
||||
This method implements a conservative detection strategy:
|
||||
- Only returns "oauth21" when we have clear indicators
|
||||
- Defaults to "oauth20" for backward compatibility
|
||||
- Respects the global oauth21_enabled flag
|
||||
|
||||
Args:
|
||||
request_params: Request parameters from authorization or token request
|
||||
|
||||
Returns:
|
||||
"oauth21" or "oauth20" based on detection
|
||||
"""
|
||||
# If OAuth 2.1 is not enabled globally, always return OAuth 2.0
|
||||
if not self.oauth21_enabled:
|
||||
return "oauth20"
|
||||
|
||||
# Use the structured type for cleaner detection logic
|
||||
from auth.oauth_types import OAuthVersionDetectionParams
|
||||
params = OAuthVersionDetectionParams.from_request(request_params)
|
||||
|
||||
# Clear OAuth 2.1 indicator: PKCE is present
|
||||
if params.has_pkce:
|
||||
return "oauth21"
|
||||
|
||||
# For public clients in OAuth 2.1 mode, we require PKCE
|
||||
# But since they didn't send PKCE, fall back to OAuth 2.0
|
||||
# This ensures backward compatibility
|
||||
|
||||
# Default to OAuth 2.0 for maximum compatibility
|
||||
return "oauth20"
|
||||
|
||||
def get_authorization_server_metadata(self, scopes: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get OAuth authorization server metadata per RFC 8414.
|
||||
|
||||
Args:
|
||||
scopes: Optional list of supported scopes to include in metadata
|
||||
|
||||
Returns:
|
||||
Authorization server metadata dictionary
|
||||
"""
|
||||
metadata = {
|
||||
"issuer": self.base_url,
|
||||
"authorization_endpoint": f"{self.base_url}/oauth2/authorize",
|
||||
"token_endpoint": f"{self.base_url}/oauth2/token",
|
||||
"registration_endpoint": f"{self.base_url}/oauth2/register",
|
||||
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
|
||||
"response_types_supported": ["code", "token"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
||||
"code_challenge_methods_supported": self.supported_code_challenge_methods,
|
||||
}
|
||||
|
||||
# Include scopes if provided
|
||||
if scopes is not None:
|
||||
metadata["scopes_supported"] = scopes
|
||||
|
||||
# Add OAuth 2.1 specific metadata
|
||||
if self.oauth21_enabled:
|
||||
metadata["pkce_required"] = True
|
||||
# OAuth 2.1 deprecates implicit flow
|
||||
metadata["response_types_supported"] = ["code"]
|
||||
# OAuth 2.1 requires exact redirect URI matching
|
||||
metadata["require_exact_redirect_uri"] = True
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
@@ -244,4 +354,24 @@ def get_allowed_origins() -> List[str]:
|
||||
|
||||
def is_oauth_configured() -> bool:
|
||||
"""Check if OAuth is properly configured."""
|
||||
return get_oauth_config().is_configured()
|
||||
return get_oauth_config().is_configured()
|
||||
|
||||
|
||||
def set_transport_mode(mode: str) -> None:
|
||||
"""Set the current transport mode."""
|
||||
get_oauth_config().set_transport_mode(mode)
|
||||
|
||||
|
||||
def get_transport_mode() -> str:
|
||||
"""Get the current transport mode."""
|
||||
return get_oauth_config().get_transport_mode()
|
||||
|
||||
|
||||
def is_oauth21_enabled() -> bool:
|
||||
"""Check if OAuth 2.1 is enabled."""
|
||||
return get_oauth_config().is_oauth21_enabled()
|
||||
|
||||
|
||||
def get_oauth_redirect_uri() -> str:
|
||||
"""Get the primary OAuth redirect URI."""
|
||||
return get_oauth_config().redirect_uri
|
||||
78
auth/oauth_types.py
Normal file
78
auth/oauth_types.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Type definitions for OAuth authentication.
|
||||
|
||||
This module provides structured types for OAuth-related parameters,
|
||||
improving code maintainability and type safety.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuth21ServiceRequest:
|
||||
"""
|
||||
Encapsulates parameters for OAuth 2.1 service authentication requests.
|
||||
|
||||
This parameter object pattern reduces function complexity and makes
|
||||
it easier to extend authentication parameters in the future.
|
||||
"""
|
||||
service_name: str
|
||||
version: str
|
||||
tool_name: str
|
||||
user_google_email: str
|
||||
required_scopes: List[str]
|
||||
session_id: Optional[str] = None
|
||||
auth_token_email: Optional[str] = None
|
||||
allow_recent_auth: bool = False
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_legacy_params(self) -> dict:
|
||||
"""Convert to legacy parameter format for backward compatibility."""
|
||||
return {
|
||||
"service_name": self.service_name,
|
||||
"version": self.version,
|
||||
"tool_name": self.tool_name,
|
||||
"user_google_email": self.user_google_email,
|
||||
"required_scopes": self.required_scopes,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthVersionDetectionParams:
|
||||
"""
|
||||
Parameters used for OAuth version detection.
|
||||
|
||||
Encapsulates the various signals we use to determine
|
||||
whether a client supports OAuth 2.1 or needs OAuth 2.0.
|
||||
"""
|
||||
client_id: Optional[str] = None
|
||||
client_secret: Optional[str] = None
|
||||
code_challenge: Optional[str] = None
|
||||
code_challenge_method: Optional[str] = None
|
||||
code_verifier: Optional[str] = None
|
||||
authenticated_user: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, request_params: Dict[str, Any]) -> "OAuthVersionDetectionParams":
|
||||
"""Create from raw request parameters."""
|
||||
return cls(
|
||||
client_id=request_params.get("client_id"),
|
||||
client_secret=request_params.get("client_secret"),
|
||||
code_challenge=request_params.get("code_challenge"),
|
||||
code_challenge_method=request_params.get("code_challenge_method"),
|
||||
code_verifier=request_params.get("code_verifier"),
|
||||
authenticated_user=request_params.get("authenticated_user"),
|
||||
session_id=request_params.get("session_id"),
|
||||
)
|
||||
|
||||
@property
|
||||
def has_pkce(self) -> bool:
|
||||
"""Check if PKCE parameters are present."""
|
||||
return bool(self.code_challenge or self.code_verifier)
|
||||
|
||||
@property
|
||||
def is_public_client(self) -> bool:
|
||||
"""Check if this appears to be a public client (no secret)."""
|
||||
return bool(self.client_id and not self.client_secret)
|
||||
@@ -334,9 +334,31 @@ def require_google_service(
|
||||
# Log authentication status
|
||||
logger.debug(f"[{tool_name}] Auth: {authenticated_user or 'none'} via {auth_method or 'none'} (session: {mcp_session_id[:8] if mcp_session_id else 'none'})")
|
||||
|
||||
from auth.oauth21_integration import is_oauth21_enabled
|
||||
|
||||
from auth.oauth_config import is_oauth21_enabled, get_oauth_config
|
||||
|
||||
# Smart OAuth version detection and fallback
|
||||
use_oauth21 = False
|
||||
oauth_version = "oauth20" # Default
|
||||
|
||||
if is_oauth21_enabled():
|
||||
# OAuth 2.1 is enabled globally, check client capabilities
|
||||
# Try to detect from context if this is an OAuth 2.1 capable client
|
||||
config = get_oauth_config()
|
||||
|
||||
# Build request params from context for version detection
|
||||
request_params = {}
|
||||
if authenticated_user:
|
||||
request_params["authenticated_user"] = authenticated_user
|
||||
if mcp_session_id:
|
||||
request_params["session_id"] = mcp_session_id
|
||||
|
||||
# Detect OAuth version based on client capabilities
|
||||
oauth_version = config.detect_oauth_version(request_params)
|
||||
use_oauth21 = (oauth_version == "oauth21")
|
||||
|
||||
logger.debug(f"[{tool_name}] OAuth version detected: {oauth_version}, will use OAuth 2.1: {use_oauth21}")
|
||||
|
||||
if use_oauth21:
|
||||
logger.debug(f"[{tool_name}] Using OAuth 2.1 flow")
|
||||
# The downstream get_authenticated_google_service_oauth21 will handle
|
||||
# whether the user's token is valid for the requested resource.
|
||||
@@ -352,8 +374,8 @@ def require_google_service(
|
||||
allow_recent_auth=False,
|
||||
)
|
||||
else:
|
||||
# If OAuth 2.1 is not enabled, always use the legacy authentication method.
|
||||
logger.debug(f"[{tool_name}] Using legacy OAuth flow")
|
||||
# Use legacy OAuth 2.0 authentication
|
||||
logger.debug(f"[{tool_name}] Using legacy OAuth 2.0 flow")
|
||||
service, actual_user_email = await get_authenticated_google_service(
|
||||
service_name=service_name,
|
||||
version=service_version,
|
||||
|
||||
Reference in New Issue
Block a user