oauth2.1 working - quick flow in mcp inspector

This commit is contained in:
Taylor Wilsdon
2025-08-02 09:52:16 -04:00
parent 8d053680c0
commit dad52829f8
18 changed files with 5211 additions and 106 deletions

View File

@@ -1,14 +1,16 @@
import logging
import os
from typing import Optional
from typing import Any, Optional
from importlib import metadata
from fastapi import Header
from fastapi.responses import HTMLResponse
from fastapi.responses import JSONResponse
from mcp.server.fastmcp import FastMCP
from starlette.requests import Request
from starlette.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from auth.google_auth import handle_auth_callback, start_auth_flow, check_client_secrets
from auth.oauth_callback_server import get_oauth_redirect_uri, ensure_oauth_callback_available
@@ -78,8 +80,34 @@ USER_GOOGLE_EMAIL = os.getenv("USER_GOOGLE_EMAIL", None)
# Transport mode detection (will be set by main.py)
_current_transport_mode = "stdio" # Default to stdio
# Basic MCP server instance
server = FastMCP(
# OAuth 2.1 authentication layer instance
_auth_layer: Optional[AuthCompatibilityLayer] = None
# Create middleware configuration
from starlette.middleware import Middleware
cors_middleware = Middleware(
CORSMiddleware,
allow_origins=["*"], # In production, specify allowed origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Custom FastMCP that adds CORS to streamable HTTP
class CORSEnabledFastMCP(FastMCP):
def streamable_http_app(self) -> "Starlette":
"""Override to add CORS middleware to the app."""
app = super().streamable_http_app()
# Add CORS as the first middleware
app.user_middleware.insert(0, cors_middleware)
# Rebuild middleware stack
app.middleware_stack = app.build_middleware_stack()
logger.info("Added CORS middleware to streamable HTTP app")
return app
# Basic MCP server instance with CORS support
server = CORSEnabledFastMCP(
name="google_workspace",
port=WORKSPACE_MCP_PORT,
host="0.0.0.0"
@@ -98,33 +126,39 @@ def get_oauth_redirect_uri_for_current_mode() -> str:
async def initialize_oauth21_auth() -> Optional[AuthCompatibilityLayer]:
"""Initialize OAuth 2.1 authentication layer if available and configured."""
global _auth_layer
if not OAUTH21_AVAILABLE:
logger.info("OAuth 2.1 not available (dependencies not installed)")
return None
try:
# Set the resource URL environment variable to match the MCP server URL
port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
base_uri = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
os.environ["OAUTH2_RESOURCE_URL"] = f"{base_uri}:{port}/mcp"
os.environ["OAUTH2_PROXY_BASE_URL"] = f"{base_uri}:{port}"
# Create authentication configuration
auth_config = AuthConfig()
if auth_config.is_oauth2_enabled():
logger.info(f"Initializing OAuth 2.1 authentication: {auth_config.get_effective_auth_mode()}")
_auth_layer = AuthCompatibilityLayer(auth_config)
await _auth_layer.start()
# Add middleware if HTTP transport is being used
if _current_transport_mode == "http" or _current_transport_mode == "streamable-http":
middleware = _auth_layer.create_enhanced_middleware()
if middleware and hasattr(server, 'app'):
server.app.add_middleware(type(middleware), **middleware.__dict__)
logger.info("Added OAuth 2.1 middleware to FastAPI app")
logger.info("OAuth 2.1 authentication initialized successfully")
else:
logger.info("OAuth 2.1 not configured, using legacy authentication only")
return _auth_layer
except Exception as e:
logger.error(f"Failed to initialize OAuth 2.1 authentication: {e}")
return None
@@ -132,7 +166,7 @@ async def initialize_oauth21_auth() -> Optional[AuthCompatibilityLayer]:
async def shutdown_oauth21_auth():
"""Shutdown OAuth 2.1 authentication layer."""
global _auth_layer
if _auth_layer:
try:
await _auth_layer.stop()
@@ -146,6 +180,7 @@ def get_auth_layer() -> Optional[AuthCompatibilityLayer]:
"""Get the global authentication layer instance."""
return _auth_layer
# Health check endpoint
@server.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
@@ -271,25 +306,23 @@ async def start_google_auth(
# OAuth 2.1 Discovery Endpoints
@server.custom_route("/.well-known/oauth-protected-resource", methods=["GET"])
@server.custom_route("/.well-known/oauth-protected-resource", methods=["GET", "OPTIONS"])
async def oauth_protected_resource(request: Request):
"""
OAuth 2.1 Protected Resource Metadata endpoint per RFC9728.
Returns metadata about this protected resource including authorization servers.
"""
from fastapi.responses import JSONResponse
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return JSONResponse(
status_code=404,
content={"error": "OAuth 2.1 not configured"}
)
try:
discovery_service = auth_layer.oauth2_handler.discovery
metadata = await discovery_service.get_protected_resource_metadata()
return JSONResponse(
content=metadata,
headers={
@@ -305,43 +338,105 @@ async def oauth_protected_resource(request: Request):
)
@server.custom_route("/.well-known/oauth-authorization-server", methods=["GET"])
@server.custom_route("/auth/discovery/authorization-server/{server_host:path}", methods=["GET", "OPTIONS"])
async def proxy_authorization_server_discovery(request: Request, server_host: str):
"""
Proxy authorization server discovery requests to avoid CORS issues.
This allows the client to discover external authorization servers through our server.
"""
import aiohttp
from fastapi.responses import JSONResponse
# Handle OPTIONS request for CORS
if request.method == "OPTIONS":
return JSONResponse(
content={},
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
}
)
# Build the discovery URL
if not server_host.startswith(('http://', 'https://')):
server_host = f"https://{server_host}"
discovery_urls = [
f"{server_host}/.well-known/oauth-authorization-server",
f"{server_host}/.well-known/openid-configuration",
]
# Try to fetch from the authorization server
async with aiohttp.ClientSession() as session:
for url in discovery_urls:
try:
async with session.get(url) as response:
if response.status == 200:
metadata = await response.json()
return JSONResponse(
content=metadata,
headers={
"Access-Control-Allow-Origin": "*",
"Cache-Control": "public, max-age=3600",
}
)
except Exception as e:
logger.debug(f"Failed to fetch from {url}: {e}")
continue
return JSONResponse(
status_code=404,
content={"error": "Authorization server metadata not found"},
headers={"Access-Control-Allow-Origin": "*"}
)
@server.custom_route("/.well-known/oauth-authorization-server", methods=["GET", "OPTIONS"])
async def oauth_authorization_server(request: Request):
"""
OAuth 2.1 Authorization Server Metadata endpoint per RFC8414.
Returns metadata about the authorization server for this resource.
"""
from fastapi.responses import JSONResponse
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return JSONResponse(
status_code=404,
content={"error": "OAuth 2.1 not configured"}
)
try:
discovery_service = auth_layer.oauth2_handler.discovery
auth_server_url = auth_layer.config.oauth2.authorization_server_url
if not auth_server_url:
return JSONResponse(
status_code=404,
content={"error": "No authorization server configured"}
)
metadata = await discovery_service.get_authorization_server_metadata(auth_server_url)
# Override issuer to point to this server for MCP-specific metadata
base_url = f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}"
metadata["issuer"] = base_url
metadata["authorization_endpoint"] = f"{auth_server_url}/o/oauth2/v2/auth"
metadata["token_endpoint"] = f"{auth_server_url}/token"
# Use our proxy for token endpoint to avoid CORS issues
metadata["token_endpoint"] = f"{base_url}/oauth2/token"
# Also proxy revocation and introspection if present
if "revocation_endpoint" in metadata:
metadata["revocation_endpoint"] = f"{base_url}/oauth2/revoke"
if "introspection_endpoint" in metadata:
metadata["introspection_endpoint"] = f"{base_url}/oauth2/introspect"
# Add dynamic client registration support
metadata["registration_endpoint"] = f"{base_url}/oauth2/register"
metadata["client_registration_types_supported"] = ["automatic"]
return JSONResponse(
content=metadata,
headers={
"Content-Type": "application/json",
"Content-Type": "application/json",
"Cache-Control": "public, max-age=3600",
}
)
@@ -353,30 +448,265 @@ async def oauth_authorization_server(request: Request):
)
@server.custom_route("/oauth2/authorize", methods=["GET"])
@server.custom_route("/.well-known/oauth-client", methods=["GET", "OPTIONS"])
async def oauth_client_info(request: Request):
"""
Provide pre-configured OAuth client information.
This is a custom endpoint to help clients that can't use dynamic registration.
"""
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return JSONResponse(
status_code=404,
content={"error": "OAuth 2.1 not configured"}
)
# Handle OPTIONS for CORS
if request.method == "OPTIONS":
return JSONResponse(
content={},
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
}
)
# Get client configuration
oauth_config = auth_layer.config.oauth2
# Return client information (without the secret for security)
client_info = {
"client_id": oauth_config.client_id,
"client_name": "MCP Server OAuth Client",
"redirect_uris": [
f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/oauth2callback"
],
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
"scope": " ".join(oauth_config.required_scopes) if oauth_config.required_scopes else "openid email profile",
"registration_required": True,
"registration_instructions": "Pre-configure your OAuth client with Google Console at https://console.cloud.google.com"
}
return JSONResponse(
content=client_info,
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
"Cache-Control": "public, max-age=3600",
}
)
@server.custom_route("/oauth2/register", methods=["POST", "OPTIONS"])
async def oauth2_dynamic_client_registration(request: Request):
"""
Dynamic Client Registration endpoint per RFC7591.
This proxies the client's registration to use our pre-configured Google OAuth credentials.
"""
from fastapi.responses import JSONResponse
import json
import uuid
from datetime import datetime
# Handle OPTIONS for CORS
if request.method == "OPTIONS":
return JSONResponse(
content={},
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type",
}
)
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return JSONResponse(
status_code=400,
content={"error": "invalid_request", "error_description": "OAuth 2.1 not configured"}
)
try:
# Parse the registration request
body = await request.body()
registration_request = json.loads(body) if body else {}
# Get our pre-configured OAuth credentials
oauth_config = auth_layer.config.oauth2
# Generate a unique client identifier for this registration
client_instance_id = str(uuid.uuid4())
# Build the registration response
# We use our pre-configured Google OAuth credentials but give the client a unique ID
registration_response = {
"client_id": oauth_config.client_id, # Use our actual Google OAuth client ID
"client_secret": oauth_config.client_secret, # Provide the secret for confidential clients
"client_id_issued_at": int(datetime.now().timestamp()),
"client_instance_id": client_instance_id,
"registration_access_token": client_instance_id, # Use instance ID as access token
"registration_client_uri": f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/oauth2/register/{client_instance_id}",
# Echo back what the client requested with our constraints
"redirect_uris": registration_request.get("redirect_uris", []),
"token_endpoint_auth_method": registration_request.get("token_endpoint_auth_method", "client_secret_post"),
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"client_name": registration_request.get("client_name", "MCP OAuth Client"),
"scope": registration_request.get("scope", " ".join(oauth_config.required_scopes)),
# Additional metadata
"client_uri": registration_request.get("client_uri"),
"logo_uri": registration_request.get("logo_uri"),
"tos_uri": registration_request.get("tos_uri"),
"policy_uri": registration_request.get("policy_uri"),
}
# Remove None values
registration_response = {k: v for k, v in registration_response.items() if v is not None}
logger.info(f"Registered dynamic client with instance ID: {client_instance_id}")
return JSONResponse(
status_code=201,
content=registration_response,
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
"Cache-Control": "no-store",
}
)
except json.JSONDecodeError:
return JSONResponse(
status_code=400,
content={"error": "invalid_request", "error_description": "Invalid JSON in request body"},
headers={"Access-Control-Allow-Origin": "*"}
)
except Exception as e:
logger.error(f"Error in dynamic client registration: {e}")
return JSONResponse(
status_code=500,
content={"error": "server_error", "error_description": "Internal server error"},
headers={"Access-Control-Allow-Origin": "*"}
)
@server.custom_route("/oauth2/token", methods=["POST", "OPTIONS"])
async def oauth2_token_proxy(request: Request):
"""
Token exchange proxy endpoint to avoid CORS issues.
Forwards token requests to Google's OAuth token endpoint.
"""
import aiohttp
import json
from fastapi.responses import JSONResponse
# Handle OPTIONS for CORS
if request.method == "OPTIONS":
return JSONResponse(
content={},
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
}
)
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return JSONResponse(
status_code=400,
content={"error": "invalid_request", "error_description": "OAuth 2.1 not configured"},
headers={"Access-Control-Allow-Origin": "*"}
)
try:
# Get the request body and headers
body = await request.body()
content_type = request.headers.get("content-type", "application/x-www-form-urlencoded")
# Always use the correct Google OAuth token endpoint
token_endpoint = "https://oauth2.googleapis.com/token"
# Forward the request to Google's token endpoint
async with aiohttp.ClientSession() as session:
headers = {"Content-Type": content_type}
async with session.post(token_endpoint, data=body, headers=headers) as response:
# Read response as text first to handle both JSON and HTML errors
response_text = await response.text()
# Try to parse as JSON
try:
response_data = json.loads(response_text)
except json.JSONDecodeError:
# If not JSON, it's likely an HTML error page
logger.error(f"Token exchange failed with HTML response: {response.status}")
logger.error(f"Response preview: {response_text[:500]}")
response_data = {
"error": "invalid_request",
"error_description": f"Token endpoint returned HTML error (status {response.status})"
}
# Log for debugging
if response.status != 200:
logger.error(f"Token exchange failed: {response.status} - {response_data}")
logger.error(f"Request body: {body.decode('utf-8')}")
else:
logger.info("Token exchange successful")
# Return the response with CORS headers
return JSONResponse(
status_code=response.status,
content=response_data,
headers={
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
"Cache-Control": "no-store",
}
)
except Exception as e:
logger.error(f"Error in token proxy: {e}")
return JSONResponse(
status_code=500,
content={"error": "server_error", "error_description": str(e)},
headers={"Access-Control-Allow-Origin": "*"}
)
@server.custom_route("/oauth2/authorize", methods=["GET", "OPTIONS"])
async def oauth2_authorize(request: Request):
"""
OAuth 2.1 authorization endpoint for MCP clients.
Redirects to the configured authorization server with proper parameters.
"""
# Handle OPTIONS request for CORS preflight
if request.method == "OPTIONS":
return JSONResponse()
from fastapi.responses import RedirectResponse
from urllib.parse import urlencode
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return create_error_response("OAuth 2.1 not configured")
try:
# Extract authorization parameters
params = dict(request.query_params)
# Validate required parameters
required_params = ["client_id", "redirect_uri", "response_type", "code_challenge", "code_challenge_method"]
missing_params = [p for p in required_params if p not in params]
if missing_params:
return create_error_response(f"Missing required parameters: {', '.join(missing_params)}")
# Build authorization URL
auth_server_url = auth_layer.config.oauth2.authorization_server_url
auth_url, state, code_verifier = await auth_layer.oauth2_handler.create_authorization_url(
@@ -385,52 +715,50 @@ async def oauth2_authorize(request: Request):
state=params.get("state"),
additional_params={k: v for k, v in params.items() if k not in ["scope", "state"]}
)
return RedirectResponse(url=auth_url)
except Exception as e:
logger.error(f"Error in OAuth 2.1 authorize endpoint: {e}")
return create_error_response(f"Authorization failed: {str(e)}")
@server.custom_route("/oauth2/token", methods=["POST"])
@server.custom_route("/oauth2/token", methods=["POST", "OPTIONS"])
async def oauth2_token(request: Request):
"""
OAuth 2.1 token endpoint for MCP clients.
Exchanges authorization codes for access tokens.
"""
from fastapi.responses import JSONResponse
auth_layer = get_auth_layer()
if not auth_layer or not auth_layer.config.is_oauth2_enabled():
return JSONResponse(
status_code=404,
content={"error": "OAuth 2.1 not configured"}
)
try:
# Parse form data
form_data = await request.form()
grant_type = form_data.get("grant_type")
if grant_type == "authorization_code":
# Handle authorization code exchange
code = form_data.get("code")
code_verifier = form_data.get("code_verifier")
code_verifier = form_data.get("code_verifier")
redirect_uri = form_data.get("redirect_uri")
if not all([code, code_verifier, redirect_uri]):
return JSONResponse(
status_code=400,
content={"error": "invalid_request", "error_description": "Missing required parameters"}
)
session_id, session = await auth_layer.oauth2_handler.exchange_code_for_session(
authorization_code=code,
code_verifier=code_verifier,
redirect_uri=redirect_uri
)
# Return token response
token_response = {
"access_token": session.token_info["access_token"],
@@ -439,12 +767,12 @@ async def oauth2_token(request: Request):
"scope": " ".join(session.scopes),
"session_id": session_id,
}
if "refresh_token" in session.token_info:
token_response["refresh_token"] = session.token_info["refresh_token"]
return JSONResponse(content=token_response)
elif grant_type == "refresh_token":
# Handle token refresh
refresh_token = form_data.get("refresh_token")
@@ -453,20 +781,20 @@ async def oauth2_token(request: Request):
status_code=400,
content={"error": "invalid_request", "error_description": "Missing refresh_token"}
)
# Find session by refresh token (simplified implementation)
# In production, you'd want a more robust refresh token lookup
return JSONResponse(
status_code=501,
content={"error": "unsupported_grant_type", "error_description": "Refresh token flow not yet implemented"}
)
else:
return JSONResponse(
status_code=400,
content={"error": "unsupported_grant_type", "error_description": f"Grant type '{grant_type}' not supported"}
)
except Exception as e:
logger.error(f"Error in OAuth 2.1 token endpoint: {e}")
return JSONResponse(