oauth2.1 stub

This commit is contained in:
Taylor Wilsdon
2025-07-19 19:33:55 -04:00
parent b3e64fc67c
commit ad981dbd98
4 changed files with 359 additions and 16 deletions

View File

@@ -14,6 +14,18 @@ from auth.google_auth import handle_auth_callback, start_auth_flow, check_client
from auth.oauth_callback_server import get_oauth_redirect_uri, ensure_oauth_callback_available
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
# OAuth 2.1 imports (optional)
try:
from auth.oauth21.config import AuthConfig, create_default_oauth2_config
from auth.oauth21.compat import AuthCompatibilityLayer
OAUTH21_AVAILABLE = True
except ImportError as e:
logger = logging.getLogger(__name__)
logger.debug(f"OAuth 2.1 not available: {e}")
OAUTH21_AVAILABLE = False
AuthCompatibilityLayer = None
AuthConfig = None
# Import shared configuration
from auth.scopes import (
OAUTH_STATE_TO_SESSION_ID_MAP,
@@ -83,6 +95,57 @@ def get_oauth_redirect_uri_for_current_mode() -> str:
"""Get OAuth redirect URI based on current transport mode."""
return get_oauth_redirect_uri(WORKSPACE_MCP_PORT, WORKSPACE_MCP_BASE_URI)
async def initialize_oauth21_auth() -> Optional[AuthCompatibilityLayer]:
"""Initialize OAuth 2.1 authentication layer if available and configured."""
global _auth_layer
if not OAUTH21_AVAILABLE:
logger.info("OAuth 2.1 not available (dependencies not installed)")
return None
try:
# Create authentication configuration
auth_config = AuthConfig()
if auth_config.is_oauth2_enabled():
logger.info(f"Initializing OAuth 2.1 authentication: {auth_config.get_effective_auth_mode()}")
_auth_layer = AuthCompatibilityLayer(auth_config)
await _auth_layer.start()
# Add middleware if HTTP transport is being used
if _current_transport_mode == "http" or _current_transport_mode == "streamable-http":
middleware = _auth_layer.create_enhanced_middleware()
if middleware and hasattr(server, 'app'):
server.app.add_middleware(type(middleware), **middleware.__dict__)
logger.info("Added OAuth 2.1 middleware to FastAPI app")
logger.info("OAuth 2.1 authentication initialized successfully")
else:
logger.info("OAuth 2.1 not configured, using legacy authentication only")
return _auth_layer
except Exception as e:
logger.error(f"Failed to initialize OAuth 2.1 authentication: {e}")
return None
async def shutdown_oauth21_auth():
"""Shutdown OAuth 2.1 authentication layer."""
global _auth_layer
if _auth_layer:
try:
await _auth_layer.stop()
logger.info("OAuth 2.1 authentication stopped")
except Exception as e:
logger.error(f"Error stopping OAuth 2.1 authentication: {e}")
finally:
_auth_layer = None
def get_auth_layer() -> Optional[AuthCompatibilityLayer]:
"""Get the global authentication layer instance."""
return _auth_layer
# Health check endpoint
@server.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
@@ -211,3 +274,208 @@ async def start_google_auth(
redirect_uri=redirect_uri
)
return auth_result
# OAuth 2.1 Discovery Endpoints
@server.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
async def oauth_protected_resource(request: Request):
"""
OAuth 2.1 Protected Resource Metadata endpoint per RFC9728.
Returns metadata about this protected resource including authorization servers.
"""
from fastapi.responses import JSONResponse
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return JSONResponse(
status_code=404,
content={"error": "OAuth 2.1 not configured"}
)
try:
discovery_service = auth_layer.oauth2_handler.discovery
metadata = await discovery_service.get_protected_resource_metadata()
return JSONResponse(
content=metadata,
headers={
"Content-Type": "application/json",
"Cache-Control": "public, max-age=3600",
}
)
except Exception as e:
logger.error(f"Error serving protected resource metadata: {e}")
return JSONResponse(
status_code=500,
content={"error": "Internal server error"}
)
@server.custom_route("/.well-known/oauth-authorization-server", methods=["GET"])
async def oauth_authorization_server(request: Request):
"""
OAuth 2.1 Authorization Server Metadata endpoint per RFC8414.
Returns metadata about the authorization server for this resource.
"""
from fastapi.responses import JSONResponse
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return JSONResponse(
status_code=404,
content={"error": "OAuth 2.1 not configured"}
)
try:
discovery_service = auth_layer.oauth2_handler.discovery
auth_server_url = auth_layer.config.oauth2.authorization_server_url
if not auth_server_url:
return JSONResponse(
status_code=404,
content={"error": "No authorization server configured"}
)
metadata = await discovery_service.get_authorization_server_metadata(auth_server_url)
# Override issuer to point to this server for MCP-specific metadata
base_url = f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}"
metadata["issuer"] = base_url
metadata["authorization_endpoint"] = f"{auth_server_url}/o/oauth2/v2/auth"
metadata["token_endpoint"] = f"{auth_server_url}/token"
return JSONResponse(
content=metadata,
headers={
"Content-Type": "application/json",
"Cache-Control": "public, max-age=3600",
}
)
except Exception as e:
logger.error(f"Error serving authorization server metadata: {e}")
return JSONResponse(
status_code=500,
content={"error": "Internal server error"}
)
@server.custom_route("/oauth2/authorize", methods=["GET"])
async def oauth2_authorize(request: Request):
"""
OAuth 2.1 authorization endpoint for MCP clients.
Redirects to the configured authorization server with proper parameters.
"""
from fastapi.responses import RedirectResponse
from urllib.parse import urlencode
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return create_error_response("OAuth 2.1 not configured")
try:
# Extract authorization parameters
params = dict(request.query_params)
# Validate required parameters
required_params = ["client_id", "redirect_uri", "response_type", "code_challenge", "code_challenge_method"]
missing_params = [p for p in required_params if p not in params]
if missing_params:
return create_error_response(f"Missing required parameters: {', '.join(missing_params)}")
# Build authorization URL
auth_server_url = auth_layer.config.oauth2.authorization_server_url
auth_url, state, code_verifier = await auth_layer.oauth2_handler.create_authorization_url(
redirect_uri=params["redirect_uri"],
scopes=params.get("scope", "").split(),
state=params.get("state"),
additional_params={k: v for k, v in params.items() if k not in ["scope", "state"]}
)
return RedirectResponse(url=auth_url)
except Exception as e:
logger.error(f"Error in OAuth 2.1 authorize endpoint: {e}")
return create_error_response(f"Authorization failed: {str(e)}")
@server.custom_route("/oauth2/token", methods=["POST"])
async def oauth2_token(request: Request):
"""
OAuth 2.1 token endpoint for MCP clients.
Exchanges authorization codes for access tokens.
"""
from fastapi.responses import JSONResponse
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return JSONResponse(
status_code=404,
content={"error": "OAuth 2.1 not configured"}
)
try:
# Parse form data
form_data = await request.form()
grant_type = form_data.get("grant_type")
if grant_type == "authorization_code":
# Handle authorization code exchange
code = form_data.get("code")
code_verifier = form_data.get("code_verifier")
redirect_uri = form_data.get("redirect_uri")
if not all([code, code_verifier, redirect_uri]):
return JSONResponse(
status_code=400,
content={"error": "invalid_request", "error_description": "Missing required parameters"}
)
session_id, session = await auth_layer.oauth2_handler.exchange_code_for_session(
authorization_code=code,
code_verifier=code_verifier,
redirect_uri=redirect_uri
)
# Return token response
token_response = {
"access_token": session.token_info["access_token"],
"token_type": "Bearer",
"expires_in": 3600, # 1 hour
"scope": " ".join(session.scopes),
"session_id": session_id,
}
if "refresh_token" in session.token_info:
token_response["refresh_token"] = session.token_info["refresh_token"]
return JSONResponse(content=token_response)
elif grant_type == "refresh_token":
# Handle token refresh
refresh_token = form_data.get("refresh_token")
if not refresh_token:
return JSONResponse(
status_code=400,
content={"error": "invalid_request", "error_description": "Missing refresh_token"}
)
# Find session by refresh token (simplified implementation)
# In production, you'd want a more robust refresh token lookup
return JSONResponse(
status_code=501,
content={"error": "unsupported_grant_type", "error_description": "Refresh token flow not yet implemented"}
)
else:
return JSONResponse(
status_code=400,
content={"error": "unsupported_grant_type", "error_description": f"Grant type '{grant_type}' not supported"}
)
except Exception as e:
logger.error(f"Error in OAuth 2.1 token endpoint: {e}")
return JSONResponse(
status_code=500,
content={"error": "server_error", "error_description": "Internal server error"}
)