refactor oauth2.1 support to fastmcp native
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional
|
||||
from importlib import metadata
|
||||
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
@@ -8,13 +8,13 @@ from starlette.requests import Request
|
||||
from starlette.middleware import Middleware
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.auth.providers.google import GoogleProvider
|
||||
|
||||
from auth.oauth21_session_store import get_oauth21_session_store, set_auth_provider
|
||||
from auth.google_auth import handle_auth_callback, start_auth_flow, check_client_secrets
|
||||
from auth.mcp_session_middleware import MCPSessionMiddleware
|
||||
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
|
||||
from auth.auth_info_middleware import AuthInfoMiddleware
|
||||
from auth.fastmcp_google_auth import GoogleWorkspaceAuthProvider
|
||||
from auth.scopes import SCOPES, get_current_scopes # noqa
|
||||
from core.config import (
|
||||
USER_GOOGLE_EMAIL,
|
||||
@@ -23,17 +23,11 @@ from core.config import (
|
||||
get_oauth_redirect_uri as get_oauth_redirect_uri_for_current_mode,
|
||||
)
|
||||
|
||||
try:
|
||||
from auth.google_remote_auth_provider import GoogleRemoteAuthProvider
|
||||
GOOGLE_REMOTE_AUTH_AVAILABLE = True
|
||||
except ImportError:
|
||||
GOOGLE_REMOTE_AUTH_AVAILABLE = False
|
||||
GoogleRemoteAuthProvider = None
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_auth_provider: Optional[Union[GoogleWorkspaceAuthProvider, GoogleRemoteAuthProvider]] = None
|
||||
_auth_provider: Optional[GoogleProvider] = None
|
||||
_legacy_callback_registered = False
|
||||
|
||||
session_middleware = Middleware(MCPSessionMiddleware)
|
||||
|
||||
@@ -67,6 +61,14 @@ def set_transport_mode(mode: str):
|
||||
_set_transport_mode(mode)
|
||||
logger.info(f"Transport: {mode}")
|
||||
|
||||
|
||||
def _ensure_legacy_callback_route() -> None:
|
||||
global _legacy_callback_registered
|
||||
if _legacy_callback_registered:
|
||||
return
|
||||
server.custom_route("/oauth2callback", methods=["GET"])(legacy_oauth2_callback)
|
||||
_legacy_callback_registered = True
|
||||
|
||||
def configure_server_for_http():
|
||||
"""
|
||||
Configures the authentication provider for HTTP transport.
|
||||
@@ -91,28 +93,31 @@ def configure_server_for_http():
|
||||
logger.warning("OAuth 2.1 enabled but OAuth credentials not configured")
|
||||
return
|
||||
|
||||
if not GOOGLE_REMOTE_AUTH_AVAILABLE:
|
||||
logger.error("CRITICAL: OAuth 2.1 enabled but FastMCP 2.11.1+ is not properly installed.")
|
||||
logger.error("Please run: uv sync --frozen")
|
||||
raise RuntimeError(
|
||||
"OAuth 2.1 requires FastMCP 2.11.1+ with RemoteAuthProvider support. "
|
||||
"Please reinstall dependencies using 'uv sync --frozen'."
|
||||
)
|
||||
|
||||
logger.info("OAuth 2.1 enabled with automatic OAuth 2.0 fallback for legacy clients")
|
||||
try:
|
||||
_auth_provider = GoogleRemoteAuthProvider()
|
||||
server.auth = _auth_provider
|
||||
set_auth_provider(_auth_provider)
|
||||
logger.debug("OAuth 2.1 authentication enabled")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize GoogleRemoteAuthProvider: {e}", exc_info=True)
|
||||
required_scopes: List[str] = sorted(get_current_scopes())
|
||||
provider = GoogleProvider(
|
||||
client_id=config.client_id,
|
||||
client_secret=config.client_secret,
|
||||
base_url=config.get_oauth_base_url(),
|
||||
redirect_path=config.redirect_path,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
server.auth = provider
|
||||
set_auth_provider(provider)
|
||||
logger.info("OAuth 2.1 enabled using FastMCP GoogleProvider")
|
||||
_auth_provider = provider
|
||||
except Exception as exc:
|
||||
logger.error("Failed to initialize FastMCP GoogleProvider: %s", exc, exc_info=True)
|
||||
raise
|
||||
else:
|
||||
logger.info("OAuth 2.0 mode - Server will use legacy authentication.")
|
||||
server.auth = None
|
||||
_auth_provider = None
|
||||
set_auth_provider(None)
|
||||
_ensure_legacy_callback_route()
|
||||
|
||||
def get_auth_provider() -> Optional[Union[GoogleWorkspaceAuthProvider, GoogleRemoteAuthProvider]]:
|
||||
|
||||
def get_auth_provider() -> Optional[GoogleProvider]:
|
||||
"""Gets the global authentication provider instance."""
|
||||
return _auth_provider
|
||||
|
||||
@@ -129,8 +134,7 @@ async def health_check(request: Request):
|
||||
"transport": get_transport_mode()
|
||||
})
|
||||
|
||||
@server.custom_route("/oauth2callback", methods=["GET"])
|
||||
async def oauth2_callback(request: Request) -> HTMLResponse:
|
||||
async def legacy_oauth2_callback(request: Request) -> HTMLResponse:
|
||||
state = request.query_params.get("state")
|
||||
code = request.query_params.get("code")
|
||||
error = request.query_params.get("error")
|
||||
|
||||
Reference in New Issue
Block a user