minimal fastmcp based oauth working
This commit is contained in:
445
auth/fastmcp_google_auth.py
Normal file
445
auth/fastmcp_google_auth.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Google Workspace Authentication Provider for FastMCP
|
||||
|
||||
This module implements OAuth 2.1 authentication for Google Workspace using FastMCP's
|
||||
built-in authentication patterns. It acts as a Resource Server (RS) that trusts
|
||||
Google as the Authorization Server (AS).
|
||||
|
||||
Key features:
|
||||
- JWT token verification using Google's public keys
|
||||
- Discovery metadata endpoints for MCP protocol compliance
|
||||
- CORS proxy endpoints to work around Google's CORS limitations
|
||||
- Session bridging to Google credentials for API access
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import aiohttp
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
from starlette.requests import Request
|
||||
|
||||
from fastmcp.server.auth.auth import AuthProvider
|
||||
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
||||
from mcp.server.auth.provider import AccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GoogleWorkspaceAuthProvider(AuthProvider):
|
||||
"""
|
||||
Authentication provider for Google Workspace integration.
|
||||
|
||||
This provider implements the Remote Authentication pattern where:
|
||||
- Google acts as the Authorization Server (AS)
|
||||
- This MCP server acts as a Resource Server (RS)
|
||||
- Tokens are verified using Google's public keys
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Google Workspace auth provider."""
|
||||
super().__init__()
|
||||
|
||||
# Get configuration from environment
|
||||
self.client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
|
||||
self.client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||
self.base_url = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
|
||||
self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
|
||||
|
||||
if not self.client_id:
|
||||
logger.warning("GOOGLE_OAUTH_CLIENT_ID not set - OAuth 2.1 authentication will not work")
|
||||
return
|
||||
|
||||
# Initialize JWT verifier for Google tokens
|
||||
self.jwt_verifier = JWTVerifier(
|
||||
jwks_uri="https://www.googleapis.com/oauth2/v3/certs",
|
||||
issuer="https://accounts.google.com",
|
||||
audience=self.client_id,
|
||||
algorithm="RS256"
|
||||
)
|
||||
|
||||
# Session store for bridging to Google credentials
|
||||
self._sessions: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async def verify_token(self, token: str) -> Optional[AccessToken]:
|
||||
"""
|
||||
Verify a bearer token issued by Google.
|
||||
|
||||
Args:
|
||||
token: The bearer token to verify
|
||||
|
||||
Returns:
|
||||
AccessToken object if valid, None otherwise
|
||||
"""
|
||||
if not self.client_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Use FastMCP's JWT verifier
|
||||
access_token = await self.jwt_verifier.verify_token(token)
|
||||
|
||||
if access_token:
|
||||
# Store session info for credential bridging
|
||||
session_id = f"google_{access_token.claims.get('sub', 'unknown')}"
|
||||
self._sessions[session_id] = {
|
||||
"access_token": token,
|
||||
"user_email": access_token.claims.get("email"),
|
||||
"claims": access_token.claims,
|
||||
"scopes": access_token.scopes or []
|
||||
}
|
||||
|
||||
logger.debug(f"Successfully verified Google token for user: {access_token.claims.get('email')}")
|
||||
|
||||
return access_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify Google token: {e}")
|
||||
return None
|
||||
|
||||
def customize_auth_routes(self, routes: List[Route]) -> List[Route]:
|
||||
"""
|
||||
Add custom routes for OAuth discovery and CORS proxy.
|
||||
|
||||
This implements:
|
||||
1. Protected resource metadata endpoint (RFC9728)
|
||||
2. Authorization server discovery proxy (to avoid CORS)
|
||||
3. Token exchange proxy (to avoid CORS)
|
||||
4. Client configuration endpoint
|
||||
"""
|
||||
|
||||
# Protected Resource Metadata endpoint
|
||||
async def protected_resource_metadata(request: Request):
|
||||
"""Return metadata about this protected resource."""
|
||||
metadata = {
|
||||
"resource": f"{self.base_url}:{self.port}",
|
||||
"authorization_servers": [
|
||||
# Point to the standard well-known endpoint
|
||||
f"{self.base_url}:{self.port}"
|
||||
],
|
||||
"bearer_methods_supported": ["header"],
|
||||
"scopes_supported": [
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/calendar",
|
||||
"https://www.googleapis.com/auth/drive",
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
"https://www.googleapis.com/auth/documents",
|
||||
"https://www.googleapis.com/auth/spreadsheets",
|
||||
"https://www.googleapis.com/auth/presentations",
|
||||
"https://www.googleapis.com/auth/chat.spaces",
|
||||
"https://www.googleapis.com/auth/forms",
|
||||
"https://www.googleapis.com/auth/tasks"
|
||||
],
|
||||
"resource_documentation": "https://developers.google.com/workspace",
|
||||
"client_registration_required": True,
|
||||
"client_configuration_endpoint": f"{self.base_url}:{self.port}/.well-known/oauth-client"
|
||||
}
|
||||
|
||||
return JSONResponse(
|
||||
content=metadata,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
routes.append(Route("/.well-known/oauth-protected-resource", protected_resource_metadata))
|
||||
|
||||
# OAuth authorization server metadata endpoint
|
||||
async def authorization_server_metadata(request: Request):
|
||||
"""Forward authorization server metadata from Google."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Try OpenID configuration first
|
||||
url = "https://accounts.google.com/.well-known/openid-configuration"
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
metadata = await response.json()
|
||||
|
||||
# Add OAuth 2.1 required fields
|
||||
metadata.setdefault("code_challenge_methods_supported", ["S256"])
|
||||
metadata.setdefault("pkce_required", True)
|
||||
|
||||
# Override token endpoint to use our proxy
|
||||
metadata["token_endpoint"] = f"{self.base_url}:{self.port}/oauth2/token"
|
||||
metadata["authorization_endpoint"] = f"{self.base_url}:{self.port}/oauth2/authorize"
|
||||
|
||||
return JSONResponse(
|
||||
content=metadata,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization"
|
||||
}
|
||||
)
|
||||
|
||||
# Fallback to default Google OAuth metadata
|
||||
return JSONResponse(
|
||||
content={
|
||||
"issuer": "https://accounts.google.com",
|
||||
"authorization_endpoint": f"{self.base_url}:{self.port}/oauth2/authorize",
|
||||
"token_endpoint": f"{self.base_url}:{self.port}/oauth2/token",
|
||||
"userinfo_endpoint": "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
"revocation_endpoint": "https://oauth2.googleapis.com/revoke",
|
||||
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
|
||||
"response_types_supported": ["code"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"pkce_required": True,
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"scopes_supported": ["openid", "email", "profile"],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"]
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Access-Control-Allow-Origin": "*"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching auth server metadata: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "Failed to fetch authorization server metadata"}
|
||||
)
|
||||
|
||||
routes.append(Route("/.well-known/oauth-authorization-server", authorization_server_metadata))
|
||||
|
||||
# Authorization server discovery proxy
|
||||
async def proxy_auth_server_discovery(request: Request):
|
||||
"""Proxy authorization server metadata to avoid CORS issues."""
|
||||
server_host = request.path_params.get("server_host", "accounts.google.com")
|
||||
|
||||
# Only allow known Google OAuth endpoints
|
||||
allowed_hosts = ["accounts.google.com", "oauth2.googleapis.com"]
|
||||
if server_host not in allowed_hosts:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"error": "Invalid authorization server"}
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch metadata from Google
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Try OpenID configuration first
|
||||
url = f"https://{server_host}/.well-known/openid-configuration"
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
metadata = await response.json()
|
||||
|
||||
# Add OAuth 2.1 required fields
|
||||
metadata.setdefault("code_challenge_methods_supported", ["S256"])
|
||||
metadata.setdefault("pkce_required", True)
|
||||
|
||||
return JSONResponse(
|
||||
content=metadata,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization"
|
||||
}
|
||||
)
|
||||
|
||||
# Fallback to default Google OAuth metadata
|
||||
return JSONResponse(
|
||||
content={
|
||||
"issuer": f"https://{server_host}",
|
||||
"authorization_endpoint": f"https://{server_host}/o/oauth2/v2/auth",
|
||||
"token_endpoint": f"https://{server_host}/token",
|
||||
"userinfo_endpoint": "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
"revocation_endpoint": f"https://{server_host}/revoke",
|
||||
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
|
||||
"response_types_supported": ["code"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"pkce_required": True,
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"scopes_supported": ["openid", "email", "profile"],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"]
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Access-Control-Allow-Origin": "*"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error proxying auth server discovery: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "Failed to fetch authorization server metadata"}
|
||||
)
|
||||
|
||||
routes.append(Route("/auth/discovery/authorization-server/{server_host:path}", proxy_auth_server_discovery))
|
||||
|
||||
# Token exchange proxy endpoint
|
||||
async def proxy_token_exchange(request: Request):
|
||||
"""Proxy token exchange to Google to avoid CORS issues."""
|
||||
if request.method == "OPTIONS":
|
||||
return JSONResponse(
|
||||
content={},
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type, Authorization"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
# Get form data
|
||||
body = await request.body()
|
||||
content_type = request.headers.get("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
# Determine which Google token endpoint to use
|
||||
token_endpoint = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# Forward request to Google
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {"Content-Type": content_type}
|
||||
|
||||
async with session.post(token_endpoint, data=body, headers=headers) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
# Log for debugging
|
||||
if response.status != 200:
|
||||
logger.error(f"Token exchange failed: {response.status} - {response_data}")
|
||||
else:
|
||||
logger.info("Token exchange successful")
|
||||
|
||||
# Store session for credential bridging
|
||||
if "access_token" in response_data:
|
||||
# Try to decode the token to get user info
|
||||
try:
|
||||
access_token = await self.verify_token(response_data["access_token"])
|
||||
if access_token:
|
||||
session_id = f"google_{access_token.claims.get('sub', 'unknown')}"
|
||||
self._sessions[session_id] = {
|
||||
"token_response": response_data,
|
||||
"user_email": access_token.claims.get("email"),
|
||||
"claims": access_token.claims
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not verify token for session storage: {e}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response.status,
|
||||
content=response_data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Cache-Control": "no-store"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in token proxy: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "server_error", "error_description": str(e)},
|
||||
headers={"Access-Control-Allow-Origin": "*"}
|
||||
)
|
||||
|
||||
routes.append(Route("/oauth2/token", proxy_token_exchange, methods=["POST", "OPTIONS"]))
|
||||
|
||||
# OAuth client configuration endpoint
|
||||
async def oauth_client_config(request: Request):
|
||||
"""Return OAuth client configuration for dynamic registration workaround."""
|
||||
if not self.client_id:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"error": "OAuth not configured"}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"client_id": self.client_id,
|
||||
"client_name": "Google Workspace MCP Server",
|
||||
"client_uri": f"{self.base_url}:{self.port}",
|
||||
"redirect_uris": [
|
||||
f"{self.base_url}:{self.port}/oauth2callback",
|
||||
"http://localhost:5173/auth/callback" # Common dev callback
|
||||
],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"scope": "openid email profile https://www.googleapis.com/auth/calendar https://www.googleapis.com/auth/drive https://www.googleapis.com/auth/gmail.modify",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"code_challenge_methods": ["S256"]
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Access-Control-Allow-Origin": "*"
|
||||
}
|
||||
)
|
||||
|
||||
routes.append(Route("/.well-known/oauth-client", oauth_client_config))
|
||||
|
||||
# OAuth authorization endpoint (redirect to Google)
|
||||
async def oauth_authorize(request: Request):
|
||||
"""Redirect to Google's authorization endpoint."""
|
||||
if request.method == "OPTIONS":
|
||||
return JSONResponse(
|
||||
content={},
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Content-Type"
|
||||
}
|
||||
)
|
||||
|
||||
# Get query parameters
|
||||
params = dict(request.query_params)
|
||||
|
||||
# Add our client ID if not provided
|
||||
if "client_id" not in params and self.client_id:
|
||||
params["client_id"] = self.client_id
|
||||
|
||||
# Ensure response_type is code
|
||||
params["response_type"] = "code"
|
||||
|
||||
# Build Google authorization URL
|
||||
google_auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params)
|
||||
|
||||
# Return redirect
|
||||
return JSONResponse(
|
||||
status_code=302,
|
||||
headers={
|
||||
"Location": google_auth_url,
|
||||
"Access-Control-Allow-Origin": "*"
|
||||
}
|
||||
)
|
||||
|
||||
routes.append(Route("/oauth2/authorize", oauth_authorize, methods=["GET", "OPTIONS"]))
|
||||
|
||||
return routes
|
||||
|
||||
def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get session information for credential bridging.
|
||||
|
||||
Args:
|
||||
session_id: The session identifier
|
||||
|
||||
Returns:
|
||||
Session information if found
|
||||
"""
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def create_session_from_token(self, token: str, user_email: str) -> str:
|
||||
"""
|
||||
Create a session from an access token for credential bridging.
|
||||
|
||||
Args:
|
||||
token: The access token
|
||||
user_email: The user's email address
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
session_id = f"google_{user_email}"
|
||||
self._sessions[session_id] = {
|
||||
"access_token": token,
|
||||
"user_email": user_email,
|
||||
"created_at": "now" # You could use datetime here
|
||||
}
|
||||
return session_id
|
||||
149
auth/mcp_oauth21_bridge.py
Normal file
149
auth/mcp_oauth21_bridge.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
MCP OAuth 2.1 Bridge
|
||||
|
||||
This module bridges MCP transport sessions with OAuth 2.1 authenticated sessions,
|
||||
allowing tool functions to access the OAuth 2.1 context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from auth.session_context import SessionContext, set_session_context
|
||||
# OAuth 2.1 is now handled by FastMCP auth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPOAuth21Bridge:
|
||||
"""
|
||||
Bridges MCP transport sessions with OAuth 2.1 sessions.
|
||||
|
||||
This class maintains a mapping between MCP transport session IDs
|
||||
and OAuth 2.1 authenticated sessions.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Map MCP transport session ID to OAuth 2.1 session info
|
||||
self._mcp_to_oauth21_map: Dict[str, Dict[str, Any]] = {}
|
||||
# Map OAuth 2.1 session ID to MCP transport session ID
|
||||
self._oauth21_to_mcp_map: Dict[str, str] = {}
|
||||
|
||||
def link_sessions(
|
||||
self,
|
||||
mcp_session_id: str,
|
||||
oauth21_session_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
auth_context: Optional[Any] = None
|
||||
):
|
||||
"""
|
||||
Link an MCP transport session with an OAuth 2.1 session.
|
||||
|
||||
Args:
|
||||
mcp_session_id: MCP transport session ID
|
||||
oauth21_session_id: OAuth 2.1 session ID
|
||||
user_id: User identifier
|
||||
auth_context: OAuth 2.1 authentication context
|
||||
"""
|
||||
session_info = {
|
||||
"oauth21_session_id": oauth21_session_id,
|
||||
"user_id": user_id,
|
||||
"auth_context": auth_context,
|
||||
"linked_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
self._mcp_to_oauth21_map[mcp_session_id] = session_info
|
||||
self._oauth21_to_mcp_map[oauth21_session_id] = mcp_session_id
|
||||
|
||||
logger.info(
|
||||
f"Linked MCP session {mcp_session_id} with OAuth 2.1 session {oauth21_session_id} "
|
||||
f"for user {user_id}"
|
||||
)
|
||||
|
||||
def get_oauth21_session(self, mcp_session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get OAuth 2.1 session info for an MCP transport session.
|
||||
|
||||
Args:
|
||||
mcp_session_id: MCP transport session ID
|
||||
|
||||
Returns:
|
||||
OAuth 2.1 session information if linked
|
||||
"""
|
||||
return self._mcp_to_oauth21_map.get(mcp_session_id)
|
||||
|
||||
def get_mcp_session(self, oauth21_session_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get MCP transport session ID for an OAuth 2.1 session.
|
||||
|
||||
Args:
|
||||
oauth21_session_id: OAuth 2.1 session ID
|
||||
|
||||
Returns:
|
||||
MCP transport session ID if linked
|
||||
"""
|
||||
return self._oauth21_to_mcp_map.get(oauth21_session_id)
|
||||
|
||||
def unlink_mcp_session(self, mcp_session_id: str):
|
||||
"""
|
||||
Remove the link for an MCP transport session.
|
||||
|
||||
Args:
|
||||
mcp_session_id: MCP transport session ID
|
||||
"""
|
||||
session_info = self._mcp_to_oauth21_map.pop(mcp_session_id, None)
|
||||
if session_info:
|
||||
oauth21_session_id = session_info.get("oauth21_session_id")
|
||||
if oauth21_session_id:
|
||||
self._oauth21_to_mcp_map.pop(oauth21_session_id, None)
|
||||
logger.info(f"Unlinked MCP session {mcp_session_id}")
|
||||
|
||||
def set_session_context_for_mcp(self, mcp_session_id: str) -> bool:
|
||||
"""
|
||||
Set the session context for the current request based on MCP session.
|
||||
|
||||
Args:
|
||||
mcp_session_id: MCP transport session ID
|
||||
|
||||
Returns:
|
||||
True if context was set, False otherwise
|
||||
"""
|
||||
session_info = self.get_oauth21_session(mcp_session_id)
|
||||
if not session_info:
|
||||
logger.debug(f"No OAuth 2.1 session linked to MCP session {mcp_session_id}")
|
||||
return False
|
||||
|
||||
# Create and set session context
|
||||
context = SessionContext(
|
||||
session_id=session_info["oauth21_session_id"],
|
||||
user_id=session_info["user_id"],
|
||||
auth_context=session_info["auth_context"],
|
||||
metadata={
|
||||
"mcp_session_id": mcp_session_id,
|
||||
"linked_at": session_info["linked_at"],
|
||||
}
|
||||
)
|
||||
|
||||
set_session_context(context)
|
||||
logger.debug(
|
||||
f"Set session context for MCP session {mcp_session_id}: "
|
||||
f"OAuth 2.1 session {context.session_id}, user {context.user_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get bridge statistics."""
|
||||
return {
|
||||
"linked_sessions": len(self._mcp_to_oauth21_map),
|
||||
"mcp_sessions": list(self._mcp_to_oauth21_map.keys()),
|
||||
"oauth21_sessions": list(self._oauth21_to_mcp_map.keys()),
|
||||
}
|
||||
|
||||
|
||||
# Global bridge instance
|
||||
_bridge = MCPOAuth21Bridge()
|
||||
|
||||
|
||||
def get_bridge() -> MCPOAuth21Bridge:
|
||||
"""Get the global MCP OAuth 2.1 bridge instance."""
|
||||
return _bridge
|
||||
78
auth/mcp_session_middleware.py
Normal file
78
auth/mcp_session_middleware.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
MCP Session Middleware
|
||||
|
||||
This middleware intercepts MCP requests and sets the session context
|
||||
for use by tool functions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable, Any
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
|
||||
from auth.session_context import (
|
||||
SessionContext,
|
||||
SessionContextManager,
|
||||
extract_session_from_headers,
|
||||
)
|
||||
# OAuth 2.1 is now handled by FastMCP auth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPSessionMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that extracts session information from requests and makes it
|
||||
available to MCP tool functions via context variables.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
"""Process request and set session context."""
|
||||
|
||||
logger.debug(f"MCPSessionMiddleware processing request: {request.method} {request.url.path}")
|
||||
|
||||
# Skip non-MCP paths
|
||||
if not request.url.path.startswith("/mcp"):
|
||||
logger.debug(f"Skipping non-MCP path: {request.url.path}")
|
||||
return await call_next(request)
|
||||
|
||||
session_context = None
|
||||
|
||||
try:
|
||||
# Extract session information
|
||||
headers = dict(request.headers)
|
||||
session_id = extract_session_from_headers(headers)
|
||||
|
||||
# Try to get OAuth 2.1 auth context
|
||||
auth_context = None
|
||||
if hasattr(request.state, "auth"):
|
||||
auth_context = request.state.auth
|
||||
|
||||
# Build session context
|
||||
if session_id or auth_context:
|
||||
session_context = SessionContext(
|
||||
session_id=session_id or (auth_context.session_id if auth_context else None),
|
||||
user_id=auth_context.user_id if auth_context else None,
|
||||
auth_context=auth_context,
|
||||
request=request,
|
||||
metadata={
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"MCP request with session: session_id={session_context.session_id}, "
|
||||
f"user_id={session_context.user_id}, path={request.url.path}"
|
||||
)
|
||||
|
||||
# Process request with session context
|
||||
with SessionContextManager(session_context):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in MCP session middleware: {e}")
|
||||
# Continue without session context
|
||||
return await call_next(request)
|
||||
141
auth/oauth21_google_bridge.py
Normal file
141
auth/oauth21_google_bridge.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Simplified OAuth 2.1 to Google Credentials Bridge
|
||||
|
||||
This module bridges FastMCP authentication to Google OAuth2 credentials
|
||||
for use with Google API clients.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from google.oauth2.credentials import Credentials
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global auth provider instance (set during server initialization)
|
||||
_auth_provider = None
|
||||
|
||||
|
||||
def set_auth_provider(provider):
|
||||
"""Set the global auth provider instance."""
|
||||
global _auth_provider
|
||||
_auth_provider = provider
|
||||
logger.info("OAuth 2.1 auth provider configured for Google credential bridging")
|
||||
|
||||
|
||||
def get_auth_provider():
|
||||
"""Get the global auth provider instance."""
|
||||
return _auth_provider
|
||||
|
||||
|
||||
def get_credentials_from_token(access_token: str, user_email: Optional[str] = None) -> Optional[Credentials]:
|
||||
"""
|
||||
Convert a bearer token to Google credentials.
|
||||
|
||||
Args:
|
||||
access_token: The bearer token
|
||||
user_email: Optional user email for session lookup
|
||||
|
||||
Returns:
|
||||
Google Credentials object or None
|
||||
"""
|
||||
if not _auth_provider:
|
||||
logger.error("Auth provider not configured")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Check if we have session info for this token
|
||||
session_info = None
|
||||
if user_email:
|
||||
session_id = f"google_{user_email}"
|
||||
session_info = _auth_provider.get_session_info(session_id)
|
||||
|
||||
# If we have a full token response (from token exchange), use it
|
||||
if session_info and "token_response" in session_info:
|
||||
token_data = session_info["token_response"]
|
||||
|
||||
# Calculate expiry
|
||||
expiry = None
|
||||
if "expires_in" in token_data:
|
||||
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_data["expires_in"])
|
||||
|
||||
credentials = Credentials(
|
||||
token=token_data["access_token"],
|
||||
refresh_token=token_data.get("refresh_token"),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=_auth_provider.client_id,
|
||||
client_secret=_auth_provider.client_secret,
|
||||
scopes=token_data.get("scope", "").split() if token_data.get("scope") else None,
|
||||
expiry=expiry
|
||||
)
|
||||
|
||||
logger.debug(f"Created Google credentials from token response for {user_email}")
|
||||
return credentials
|
||||
|
||||
# Otherwise, create minimal credentials with just the access token
|
||||
else:
|
||||
# Assume token is valid for 1 hour (typical for Google tokens)
|
||||
expiry = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
credentials = Credentials(
|
||||
token=access_token,
|
||||
refresh_token=None,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=_auth_provider.client_id,
|
||||
client_secret=_auth_provider.client_secret,
|
||||
scopes=None, # Will be populated from token claims if available
|
||||
expiry=expiry
|
||||
)
|
||||
|
||||
logger.debug("Created Google credentials from bearer token")
|
||||
return credentials
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Google credentials from token: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def store_token_session(token_response: dict, user_email: str) -> str:
|
||||
"""
|
||||
Store a token response in the session store.
|
||||
|
||||
Args:
|
||||
token_response: OAuth token response from Google
|
||||
user_email: User's email address
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
if not _auth_provider:
|
||||
logger.error("Auth provider not configured")
|
||||
return ""
|
||||
|
||||
try:
|
||||
session_id = f"google_{user_email}"
|
||||
_auth_provider._sessions[session_id] = {
|
||||
"token_response": token_response,
|
||||
"user_email": user_email,
|
||||
"created_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Also store in the global OAuth21 session store for compatibility
|
||||
session_store = get_oauth21_session_store()
|
||||
session_store.store_session(
|
||||
user_email=user_email,
|
||||
access_token=token_response.get("access_token"),
|
||||
refresh_token=token_response.get("refresh_token"),
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=_auth_provider.client_id,
|
||||
client_secret=_auth_provider.client_secret,
|
||||
scopes=token_response.get("scope", "").split() if token_response.get("scope") else None,
|
||||
expiry=datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600)),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(f"Stored token session for {user_email}")
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store token session: {e}")
|
||||
return ""
|
||||
208
auth/oauth21_integration.py
Normal file
208
auth/oauth21_integration.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
OAuth 2.1 Integration for Google Services
|
||||
|
||||
This module provides integration between FastMCP OAuth sessions and Google services,
|
||||
allowing authenticated sessions to be passed through to Google API calls.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Tuple, Any, Dict
|
||||
from functools import lru_cache
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
from auth.oauth21_google_bridge import get_auth_provider, get_credentials_from_token
|
||||
from auth.google_auth import (
|
||||
save_credentials_to_session,
|
||||
load_credentials_from_session,
|
||||
GoogleAuthenticationError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuth21GoogleServiceBuilder:
|
||||
"""Builds Google services using FastMCP OAuth authenticated sessions."""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the service builder.
|
||||
"""
|
||||
self._service_cache: Dict[str, Tuple[Any, str]] = {}
|
||||
|
||||
def extract_session_from_context(self, context: Optional[Dict[str, Any]] = None) -> Optional[str]:
|
||||
"""
|
||||
Extract session ID from various context sources.
|
||||
|
||||
Args:
|
||||
context: Context dictionary that may contain session information
|
||||
|
||||
Returns:
|
||||
Session ID if found, None otherwise
|
||||
"""
|
||||
if not context:
|
||||
return None
|
||||
|
||||
# Try to extract from OAuth 2.1 auth context
|
||||
if "auth_context" in context and hasattr(context["auth_context"], "session_id"):
|
||||
return context["auth_context"].session_id
|
||||
|
||||
# Try direct session_id
|
||||
if "session_id" in context:
|
||||
return context["session_id"]
|
||||
|
||||
# Try from request state
|
||||
if "request" in context:
|
||||
request = context["request"]
|
||||
if hasattr(request, "state") and hasattr(request.state, "auth"):
|
||||
auth_ctx = request.state.auth
|
||||
if hasattr(auth_ctx, "session_id"):
|
||||
return auth_ctx.session_id
|
||||
|
||||
return None
|
||||
|
||||
async def get_authenticated_service_with_session(
|
||||
self,
|
||||
service_name: str,
|
||||
version: str,
|
||||
tool_name: str,
|
||||
user_google_email: str,
|
||||
required_scopes: list[str],
|
||||
session_id: Optional[str] = None,
|
||||
auth_context: Optional[Any] = None,
|
||||
) -> Tuple[Any, str]:
|
||||
"""
|
||||
Get authenticated Google service using OAuth 2.1 session if available.
|
||||
|
||||
Args:
|
||||
service_name: Google service name (e.g., "gmail", "drive")
|
||||
version: API version (e.g., "v1", "v3")
|
||||
tool_name: Name of the tool for logging
|
||||
user_google_email: User's Google email
|
||||
required_scopes: Required OAuth scopes
|
||||
session_id: OAuth 2.1 session ID
|
||||
auth_context: OAuth 2.1 authentication context
|
||||
|
||||
Returns:
|
||||
Tuple of (service instance, actual user email)
|
||||
|
||||
Raises:
|
||||
GoogleAuthenticationError: If authentication fails
|
||||
"""
|
||||
cache_key = f"{user_google_email}:{service_name}:{version}:{':'.join(sorted(required_scopes))}"
|
||||
|
||||
# Check cache first
|
||||
if cache_key in self._service_cache:
|
||||
logger.debug(f"[{tool_name}] Using cached service for {user_google_email}")
|
||||
return self._service_cache[cache_key]
|
||||
|
||||
try:
|
||||
# First check the global OAuth 2.1 session store
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
store = get_oauth21_session_store()
|
||||
credentials = store.get_credentials(user_google_email)
|
||||
|
||||
if credentials and credentials.valid:
|
||||
logger.info(f"[{tool_name}] Found OAuth 2.1 credentials in global store for {user_google_email}")
|
||||
|
||||
# Build the service
|
||||
service = await asyncio.to_thread(
|
||||
build, service_name, version, credentials=credentials
|
||||
)
|
||||
|
||||
# Cache the service
|
||||
self._service_cache[cache_key] = (service, user_google_email)
|
||||
|
||||
return service, user_google_email
|
||||
|
||||
# OAuth 2.1 is now handled by FastMCP - removed legacy auth_layer code
|
||||
|
||||
# Fall back to legacy authentication
|
||||
logger.debug(f"[{tool_name}] Falling back to legacy authentication for {user_google_email}")
|
||||
from auth.google_auth import get_authenticated_google_service as legacy_get_service
|
||||
|
||||
return await legacy_get_service(
|
||||
service_name=service_name,
|
||||
version=version,
|
||||
tool_name=tool_name,
|
||||
user_google_email=user_google_email,
|
||||
required_scopes=required_scopes,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{tool_name}] Authentication failed for {user_google_email}: {e}")
|
||||
raise GoogleAuthenticationError(
|
||||
f"Failed to authenticate for {service_name}: {str(e)}"
|
||||
)
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear the service cache."""
|
||||
self._service_cache.clear()
|
||||
logger.debug("Cleared OAuth 2.1 service cache")
|
||||
|
||||
|
||||
# Global instance
|
||||
_global_service_builder: Optional[OAuth21GoogleServiceBuilder] = None
|
||||
|
||||
|
||||
def get_oauth21_service_builder() -> OAuth21GoogleServiceBuilder:
|
||||
"""Get the global OAuth 2.1 service builder instance."""
|
||||
global _global_service_builder
|
||||
if _global_service_builder is None:
|
||||
_global_service_builder = OAuth21GoogleServiceBuilder()
|
||||
return _global_service_builder
|
||||
|
||||
|
||||
def set_auth_layer(auth_layer):
|
||||
"""
|
||||
Legacy compatibility function - no longer needed with FastMCP auth.
|
||||
"""
|
||||
logger.info("set_auth_layer called - OAuth is now handled by FastMCP")
|
||||
|
||||
|
||||
async def get_authenticated_google_service_oauth21(
|
||||
service_name: str,
|
||||
version: str,
|
||||
tool_name: str,
|
||||
user_google_email: str,
|
||||
required_scopes: list[str],
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, str]:
|
||||
"""
|
||||
Enhanced version of get_authenticated_google_service that supports OAuth 2.1.
|
||||
|
||||
This function checks for OAuth 2.1 session context and uses it if available,
|
||||
otherwise falls back to legacy authentication.
|
||||
|
||||
Args:
|
||||
service_name: Google service name
|
||||
version: API version
|
||||
tool_name: Tool name for logging
|
||||
user_google_email: User's Google email
|
||||
required_scopes: Required OAuth scopes
|
||||
context: Optional context containing session information
|
||||
|
||||
Returns:
|
||||
Tuple of (service instance, actual user email)
|
||||
"""
|
||||
builder = get_oauth21_service_builder()
|
||||
|
||||
# FastMCP handles context now - extract any session info
|
||||
session_id = None
|
||||
auth_context = None
|
||||
|
||||
if context:
|
||||
session_id = builder.extract_session_from_context(context)
|
||||
auth_context = context.get("auth_context")
|
||||
|
||||
return await builder.get_authenticated_service_with_session(
|
||||
service_name=service_name,
|
||||
version=version,
|
||||
tool_name=tool_name,
|
||||
user_google_email=user_google_email,
|
||||
required_scopes=required_scopes,
|
||||
session_id=session_id,
|
||||
auth_context=auth_context,
|
||||
)
|
||||
132
auth/oauth21_session_store.py
Normal file
132
auth/oauth21_session_store.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
OAuth 2.1 Session Store for Google Services
|
||||
|
||||
This module provides a global store for OAuth 2.1 authenticated sessions
|
||||
that can be accessed by Google service decorators.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional, Any
|
||||
from threading import RLock
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuth21SessionStore:
|
||||
"""
|
||||
Global store for OAuth 2.1 authenticated sessions.
|
||||
|
||||
This store maintains a mapping of user emails to their OAuth 2.1
|
||||
authenticated credentials, allowing Google services to access them.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._sessions: Dict[str, Dict[str, Any]] = {}
|
||||
self._lock = RLock()
|
||||
|
||||
def store_session(
|
||||
self,
|
||||
user_email: str,
|
||||
access_token: str,
|
||||
refresh_token: Optional[str] = None,
|
||||
token_uri: str = "https://oauth2.googleapis.com/token",
|
||||
client_id: Optional[str] = None,
|
||||
client_secret: Optional[str] = None,
|
||||
scopes: Optional[list] = None,
|
||||
expiry: Optional[Any] = None,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Store OAuth 2.1 session information.
|
||||
|
||||
Args:
|
||||
user_email: User's email address
|
||||
access_token: OAuth 2.1 access token
|
||||
refresh_token: OAuth 2.1 refresh token
|
||||
token_uri: Token endpoint URI
|
||||
client_id: OAuth client ID
|
||||
client_secret: OAuth client secret
|
||||
scopes: List of granted scopes
|
||||
expiry: Token expiry time
|
||||
session_id: OAuth 2.1 session ID
|
||||
"""
|
||||
with self._lock:
|
||||
session_info = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_uri": token_uri,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"scopes": scopes or [],
|
||||
"expiry": expiry,
|
||||
"session_id": session_id,
|
||||
}
|
||||
|
||||
self._sessions[user_email] = session_info
|
||||
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})")
|
||||
|
||||
def get_credentials(self, user_email: str) -> Optional[Credentials]:
|
||||
"""
|
||||
Get Google credentials for a user from OAuth 2.1 session.
|
||||
|
||||
Args:
|
||||
user_email: User's email address
|
||||
|
||||
Returns:
|
||||
Google Credentials object or None
|
||||
"""
|
||||
with self._lock:
|
||||
session_info = self._sessions.get(user_email)
|
||||
if not session_info:
|
||||
logger.debug(f"No OAuth 2.1 session found for {user_email}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Create Google credentials from session info
|
||||
credentials = Credentials(
|
||||
token=session_info["access_token"],
|
||||
refresh_token=session_info.get("refresh_token"),
|
||||
token_uri=session_info["token_uri"],
|
||||
client_id=session_info.get("client_id"),
|
||||
client_secret=session_info.get("client_secret"),
|
||||
scopes=session_info.get("scopes", []),
|
||||
expiry=session_info.get("expiry"),
|
||||
)
|
||||
|
||||
logger.debug(f"Retrieved OAuth 2.1 credentials for {user_email}")
|
||||
return credentials
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create credentials for {user_email}: {e}")
|
||||
return None
|
||||
|
||||
def remove_session(self, user_email: str):
|
||||
"""Remove session for a user."""
|
||||
with self._lock:
|
||||
if user_email in self._sessions:
|
||||
del self._sessions[user_email]
|
||||
logger.info(f"Removed OAuth 2.1 session for {user_email}")
|
||||
|
||||
def has_session(self, user_email: str) -> bool:
|
||||
"""Check if a user has an active session."""
|
||||
with self._lock:
|
||||
return user_email in self._sessions
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get store statistics."""
|
||||
with self._lock:
|
||||
return {
|
||||
"total_sessions": len(self._sessions),
|
||||
"users": list(self._sessions.keys()),
|
||||
}
|
||||
|
||||
|
||||
# Global instance
|
||||
_global_store = OAuth21SessionStore()
|
||||
|
||||
|
||||
def get_oauth21_session_store() -> OAuth21SessionStore:
|
||||
"""Get the global OAuth 2.1 session store."""
|
||||
return _global_store
|
||||
@@ -7,15 +7,10 @@ from datetime import datetime, timedelta
|
||||
from google.auth.exceptions import RefreshError
|
||||
from auth.google_auth import get_authenticated_google_service, GoogleAuthenticationError
|
||||
|
||||
# Try to import OAuth 2.1 integration
|
||||
try:
|
||||
from auth.oauth21_integration import get_authenticated_google_service_oauth21
|
||||
from auth.session_context import get_session_context
|
||||
OAUTH21_INTEGRATION_AVAILABLE = True
|
||||
except ImportError:
|
||||
OAUTH21_INTEGRATION_AVAILABLE = False
|
||||
get_authenticated_google_service_oauth21 = None
|
||||
get_session_context = None
|
||||
# OAuth 2.1 integration is now handled by FastMCP auth
|
||||
OAUTH21_INTEGRATION_AVAILABLE = False
|
||||
get_authenticated_google_service_oauth21 = None
|
||||
get_session_context = None
|
||||
from auth.scopes import (
|
||||
GMAIL_READONLY_SCOPE, GMAIL_SEND_SCOPE, GMAIL_COMPOSE_SCOPE, GMAIL_MODIFY_SCOPE, GMAIL_LABELS_SCOPE,
|
||||
DRIVE_READONLY_SCOPE, DRIVE_FILE_SCOPE,
|
||||
|
||||
216
auth/service_decorator_oauth21.py.bak
Normal file
216
auth/service_decorator_oauth21.py.bak
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
Enhanced Service Decorator with OAuth 2.1 Support
|
||||
|
||||
This module provides an enhanced version of the service decorator that can
|
||||
extract and use OAuth 2.1 session context from FastMCP.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Dict, List, Optional, Any, Callable, Union
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from google.auth.exceptions import RefreshError
|
||||
|
||||
from auth.service_decorator import (
|
||||
SERVICE_CONFIGS,
|
||||
SCOPE_GROUPS,
|
||||
_resolve_scopes,
|
||||
_get_cache_key,
|
||||
_is_cache_valid,
|
||||
_handle_token_refresh_error,
|
||||
_get_cached_service,
|
||||
_cache_service,
|
||||
GoogleAuthenticationError,
|
||||
)
|
||||
from auth.oauth21_integration import get_authenticated_google_service_oauth21
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_context_from_args(args: tuple, kwargs: dict, sig: inspect.Signature) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract FastMCP Context from function arguments.
|
||||
|
||||
Args:
|
||||
args: Positional arguments
|
||||
kwargs: Keyword arguments
|
||||
sig: Function signature
|
||||
|
||||
Returns:
|
||||
Context information if found
|
||||
"""
|
||||
param_names = list(sig.parameters.keys())
|
||||
|
||||
# Check for Context type annotation
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.annotation and "Context" in str(param.annotation):
|
||||
# Found Context parameter
|
||||
if param_name in kwargs:
|
||||
ctx = kwargs[param_name]
|
||||
else:
|
||||
try:
|
||||
param_index = param_names.index(param_name)
|
||||
if param_index < len(args):
|
||||
ctx = args[param_index]
|
||||
else:
|
||||
continue
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Extract relevant information from Context
|
||||
if ctx:
|
||||
context_info = {}
|
||||
|
||||
# Try to get session_id
|
||||
if hasattr(ctx, "session_id"):
|
||||
context_info["session_id"] = ctx.session_id
|
||||
|
||||
# Try to get request object
|
||||
if hasattr(ctx, "request"):
|
||||
context_info["request"] = ctx.request
|
||||
|
||||
# Try to get auth context from request state
|
||||
if hasattr(ctx, "request") and hasattr(ctx.request, "state"):
|
||||
if hasattr(ctx.request.state, "auth"):
|
||||
context_info["auth_context"] = ctx.request.state.auth
|
||||
|
||||
return context_info if context_info else None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def require_google_service_oauth21(
|
||||
service_type: str,
|
||||
scopes: Union[str, List[str]],
|
||||
version: Optional[str] = None,
|
||||
cache_enabled: bool = True,
|
||||
fallback_to_legacy: bool = True
|
||||
):
|
||||
"""
|
||||
Enhanced decorator that injects authenticated Google service with OAuth 2.1 support.
|
||||
|
||||
This decorator checks for FastMCP Context in the function parameters and uses
|
||||
OAuth 2.1 session information if available, otherwise falls back to legacy auth.
|
||||
|
||||
Args:
|
||||
service_type: Type of Google service (e.g., 'gmail', 'drive')
|
||||
scopes: Required scopes or scope aliases
|
||||
version: API version (optional, uses default if not specified)
|
||||
cache_enabled: Whether to cache service instances
|
||||
fallback_to_legacy: Whether to fall back to legacy auth if OAuth 2.1 fails
|
||||
|
||||
Usage:
|
||||
@require_google_service_oauth21("gmail", "gmail_read")
|
||||
async def search_emails(service, user_google_email: str, ctx: Context):
|
||||
# service is automatically injected
|
||||
# ctx provides OAuth 2.1 session context
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
# Get service configuration
|
||||
if service_type not in SERVICE_CONFIGS:
|
||||
raise ValueError(f"Unknown service type: {service_type}")
|
||||
|
||||
service_config = SERVICE_CONFIGS[service_type]
|
||||
service_name = service_config["service"]
|
||||
service_version = version or service_config["version"]
|
||||
|
||||
# Resolve scopes
|
||||
resolved_scopes = _resolve_scopes(scopes)
|
||||
|
||||
# Create wrapper with modified signature
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.values())
|
||||
|
||||
# Remove 'service' parameter from signature
|
||||
wrapper_params = [p for p in params if p.name != 'service']
|
||||
wrapper_sig = sig.replace(parameters=wrapper_params)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract user_google_email
|
||||
user_google_email = None
|
||||
if 'user_google_email' in kwargs:
|
||||
user_google_email = kwargs['user_google_email']
|
||||
else:
|
||||
param_names = list(sig.parameters.keys())
|
||||
try:
|
||||
user_email_index = param_names.index('user_google_email')
|
||||
if user_email_index < len(args):
|
||||
user_google_email = args[user_email_index]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if not user_google_email:
|
||||
raise ValueError("user_google_email parameter is required")
|
||||
|
||||
# Extract context information
|
||||
context = _extract_context_from_args(args, kwargs, sig)
|
||||
|
||||
service = None
|
||||
actual_user_email = user_google_email
|
||||
|
||||
# Check cache if enabled
|
||||
if cache_enabled:
|
||||
cache_key = _get_cache_key(user_google_email, service_name, service_version, resolved_scopes)
|
||||
cached_result = _get_cached_service(cache_key)
|
||||
if cached_result:
|
||||
service, actual_user_email = cached_result
|
||||
logger.debug(f"Using cached service for {user_google_email}")
|
||||
|
||||
if service is None:
|
||||
try:
|
||||
tool_name = func.__name__
|
||||
|
||||
# Try OAuth 2.1 authentication with context
|
||||
if context:
|
||||
logger.debug(f"Attempting OAuth 2.1 authentication for {tool_name}")
|
||||
service, actual_user_email = await get_authenticated_google_service_oauth21(
|
||||
service_name=service_name,
|
||||
version=service_version,
|
||||
tool_name=tool_name,
|
||||
user_google_email=user_google_email,
|
||||
required_scopes=resolved_scopes,
|
||||
context=context,
|
||||
)
|
||||
elif fallback_to_legacy:
|
||||
# Fall back to legacy authentication
|
||||
logger.debug(f"Using legacy authentication for {tool_name}")
|
||||
from auth.google_auth import get_authenticated_google_service
|
||||
service, actual_user_email = await get_authenticated_google_service(
|
||||
service_name=service_name,
|
||||
version=service_version,
|
||||
tool_name=tool_name,
|
||||
user_google_email=user_google_email,
|
||||
required_scopes=resolved_scopes,
|
||||
)
|
||||
else:
|
||||
raise GoogleAuthenticationError(
|
||||
"OAuth 2.1 context required but not found"
|
||||
)
|
||||
|
||||
# Cache the service if enabled
|
||||
if cache_enabled and service:
|
||||
cache_key = _get_cache_key(user_google_email, service_name, service_version, resolved_scopes)
|
||||
_cache_service(cache_key, service, actual_user_email)
|
||||
|
||||
except GoogleAuthenticationError as e:
|
||||
raise Exception(str(e))
|
||||
|
||||
# Call the original function with the service object injected
|
||||
try:
|
||||
return await func(service, *args, **kwargs)
|
||||
except RefreshError as e:
|
||||
error_message = _handle_token_refresh_error(e, actual_user_email, service_name)
|
||||
raise Exception(error_message)
|
||||
|
||||
# Set the wrapper's signature to the one without 'service'
|
||||
wrapper.__signature__ = wrapper_sig
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Alias for backward compatibility
|
||||
require_google_service = require_google_service_oauth21
|
||||
116
auth/session_context.py
Normal file
116
auth/session_context.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Session Context Management for OAuth 2.1 Integration
|
||||
|
||||
This module provides thread-local storage for OAuth 2.1 session context,
|
||||
allowing tool functions to access the current authenticated session.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variable to store the current session information
|
||||
_current_session_context: contextvars.ContextVar[Optional['SessionContext']] = contextvars.ContextVar(
|
||||
'current_session_context',
|
||||
default=None
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""Container for session-related information."""
|
||||
session_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
auth_context: Optional[Any] = None
|
||||
request: Optional[Any] = None
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
|
||||
|
||||
def set_session_context(context: Optional[SessionContext]):
|
||||
"""
|
||||
Set the current session context.
|
||||
|
||||
Args:
|
||||
context: The session context to set
|
||||
"""
|
||||
_current_session_context.set(context)
|
||||
if context:
|
||||
logger.debug(f"Set session context: session_id={context.session_id}, user_id={context.user_id}")
|
||||
else:
|
||||
logger.debug("Cleared session context")
|
||||
|
||||
|
||||
def get_session_context() -> Optional[SessionContext]:
|
||||
"""
|
||||
Get the current session context.
|
||||
|
||||
Returns:
|
||||
The current session context or None
|
||||
"""
|
||||
return _current_session_context.get()
|
||||
|
||||
|
||||
def clear_session_context():
|
||||
"""Clear the current session context."""
|
||||
set_session_context(None)
|
||||
|
||||
|
||||
class SessionContextManager:
|
||||
"""
|
||||
Context manager for temporarily setting session context.
|
||||
|
||||
Usage:
|
||||
with SessionContextManager(session_context):
|
||||
# Code that needs access to session context
|
||||
pass
|
||||
"""
|
||||
|
||||
def __init__(self, context: Optional[SessionContext]):
|
||||
self.context = context
|
||||
self.token = None
|
||||
|
||||
def __enter__(self):
|
||||
"""Set the session context."""
|
||||
self.token = _current_session_context.set(self.context)
|
||||
return self.context
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Reset the session context."""
|
||||
if self.token:
|
||||
_current_session_context.reset(self.token)
|
||||
|
||||
|
||||
def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Extract session ID from request headers.
|
||||
|
||||
Args:
|
||||
headers: Request headers
|
||||
|
||||
Returns:
|
||||
Session ID if found
|
||||
"""
|
||||
# Try different header names
|
||||
session_id = headers.get("mcp-session-id") or headers.get("Mcp-Session-Id")
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
session_id = headers.get("x-session-id") or headers.get("X-Session-ID")
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
# Try Authorization header for Bearer token
|
||||
auth_header = headers.get("authorization") or headers.get("Authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer "):
|
||||
# For now, we can't extract session from bearer token without the full context
|
||||
# This would need to be handled by the OAuth 2.1 middleware
|
||||
pass
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user