refactor oauth2.1 support to fastmcp native
This commit is contained in:
@@ -9,6 +9,7 @@ Supports both OAuth 2.0 and OAuth 2.1 with automatic client capability detection
|
||||
"""
|
||||
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
|
||||
@@ -49,6 +50,10 @@ class OAuthConfig:
|
||||
|
||||
# Redirect URI configuration
|
||||
self.redirect_uri = self._get_redirect_uri()
|
||||
self.redirect_path = self._get_redirect_path(self.redirect_uri)
|
||||
|
||||
# Ensure FastMCP's Google provider picks up our existing configuration
|
||||
self._apply_fastmcp_google_env()
|
||||
|
||||
def _get_redirect_uri(self) -> str:
|
||||
"""
|
||||
@@ -62,6 +67,32 @@ class OAuthConfig:
|
||||
return explicit_uri
|
||||
return f"{self.base_url}/oauth2callback"
|
||||
|
||||
@staticmethod
|
||||
def _get_redirect_path(uri: str) -> str:
|
||||
"""Extract the redirect path from a full redirect URI."""
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme or parsed.netloc:
|
||||
path = parsed.path or "/oauth2callback"
|
||||
else:
|
||||
# If the value was already a path, ensure it starts with '/'
|
||||
path = uri if uri.startswith("/") else f"/{uri}"
|
||||
return path or "/oauth2callback"
|
||||
|
||||
def _apply_fastmcp_google_env(self) -> None:
|
||||
"""Mirror legacy GOOGLE_* env vars into FastMCP Google provider settings."""
|
||||
if not self.client_id:
|
||||
return
|
||||
|
||||
def _set_if_absent(key: str, value: Optional[str]) -> None:
|
||||
if value and key not in os.environ:
|
||||
os.environ[key] = value
|
||||
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH", "fastmcp.server.auth.providers.google.GoogleProvider" if self.oauth21_enabled else None)
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_ID", self.client_id)
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_SECRET", self.client_secret)
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_BASE_URL", self.get_oauth_base_url())
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_REDIRECT_PATH", self.redirect_path)
|
||||
|
||||
def get_redirect_uris(self) -> List[str]:
|
||||
"""
|
||||
Get all valid OAuth redirect URIs.
|
||||
@@ -156,6 +187,7 @@ class OAuthConfig:
|
||||
"external_url": self.external_url,
|
||||
"effective_oauth_url": self.get_oauth_base_url(),
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"redirect_path": self.redirect_path,
|
||||
"client_configured": bool(self.client_id),
|
||||
"oauth21_enabled": self.oauth21_enabled,
|
||||
"pkce_required": self.pkce_required,
|
||||
@@ -350,4 +382,4 @@ def get_oauth_redirect_uri() -> str:
|
||||
|
||||
def is_stateless_mode() -> bool:
|
||||
"""Check if stateless mode is enabled."""
|
||||
return get_oauth_config().stateless_mode
|
||||
return get_oauth_config().stateless_mode
|
||||
|
||||
Reference in New Issue
Block a user