feat: initial commit from workspace-mcp
Some checks failed
Check Maintainer Edits Enabled / check-maintainer-edits (pull_request) Has been cancelled
Check Maintainer Edits Enabled / check-maintainer-edits-internal (pull_request) Has been cancelled
Docker Build and Push to GHCR / build-and-push (pull_request) Has been cancelled
Ruff / ruff (pull_request) Has been cancelled

This commit is contained in:
2026-03-17 19:23:33 -05:00
commit 395f0e2029
138 changed files with 41691 additions and 0 deletions

1
auth/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Make the auth directory a Python package

View File

@@ -0,0 +1,378 @@
"""
Authentication middleware to populate context state with user information
"""
import logging
import time
from fastmcp.server.middleware import Middleware, MiddlewareContext
from fastmcp.server.dependencies import get_access_token
from fastmcp.server.dependencies import get_http_headers
from auth.external_oauth_provider import get_session_time
from auth.oauth21_session_store import ensure_session_from_access_token
from auth.oauth_types import WorkspaceAccessToken
# Configure logging
logger = logging.getLogger(__name__)
class AuthInfoMiddleware(Middleware):
"""
Middleware to extract authentication information from JWT tokens
and populate the FastMCP context state for use in tools and prompts.
"""
def __init__(self):
super().__init__()
self.auth_provider_type = "GoogleProvider"
async def _process_request_for_auth(self, context: MiddlewareContext):
"""Helper to extract, verify, and store auth info from a request."""
if not context.fastmcp_context:
logger.warning("No fastmcp_context available")
return
authenticated_user = None
auth_via = None
# First check if FastMCP has already validated an access token
try:
access_token = get_access_token()
if access_token:
logger.info("[AuthInfoMiddleware] FastMCP access_token found")
user_email = getattr(access_token, "email", None)
if not user_email and hasattr(access_token, "claims"):
user_email = access_token.claims.get("email")
if user_email:
logger.info(
f"✓ Using FastMCP validated token for user: {user_email}"
)
await context.fastmcp_context.set_state(
"authenticated_user_email", user_email
)
await context.fastmcp_context.set_state(
"authenticated_via", "fastmcp_oauth"
)
await context.fastmcp_context.set_state(
"access_token", access_token, serializable=False
)
authenticated_user = user_email
auth_via = "fastmcp_oauth"
else:
logger.warning(
f"FastMCP access_token found but no email. Type: {type(access_token).__name__}"
)
except Exception as e:
logger.debug(f"Could not get FastMCP access_token: {e}")
# Try to get the HTTP request to extract Authorization header
if not authenticated_user:
try:
# Use the new FastMCP method to get HTTP headers
headers = get_http_headers()
logger.info(
f"[AuthInfoMiddleware] get_http_headers() returned: {headers is not None}, keys: {list(headers.keys()) if headers else 'None'}"
)
if headers:
logger.debug("Processing HTTP headers for authentication")
# Get the Authorization header
auth_header = headers.get("authorization", "")
if auth_header.startswith("Bearer "):
token_str = auth_header[7:] # Remove "Bearer " prefix
logger.info("Found Bearer token in request")
# For Google OAuth tokens (ya29.*), we need to verify them differently
if token_str.startswith("ya29."):
logger.debug("Detected Google OAuth access token format")
# Verify the token to get user info
from core.server import get_auth_provider
auth_provider = get_auth_provider()
if auth_provider:
try:
# Verify the token
verified_auth = await auth_provider.verify_token(
token_str
)
if verified_auth:
# Extract user email from verified token
user_email = getattr(
verified_auth, "email", None
)
if not user_email and hasattr(
verified_auth, "claims"
):
user_email = verified_auth.claims.get(
"email"
)
if isinstance(
verified_auth, WorkspaceAccessToken
):
# ExternalOAuthProvider returns a fully-formed WorkspaceAccessToken
access_token = verified_auth
else:
# Standard GoogleProvider returns a base AccessToken;
# wrap it in WorkspaceAccessToken for identical downstream handling
verified_expires = getattr(
verified_auth, "expires_at", None
)
access_token = WorkspaceAccessToken(
token=token_str,
client_id=getattr(
verified_auth, "client_id", None
)
or "google",
scopes=getattr(
verified_auth, "scopes", []
)
or [],
session_id=f"google_oauth_{token_str[:8]}",
expires_at=verified_expires
if verified_expires is not None
else int(time.time())
+ get_session_time(),
claims=getattr(
verified_auth, "claims", {}
)
or {},
sub=getattr(verified_auth, "sub", None)
or user_email,
email=user_email,
)
# Store in context state - this is the authoritative authentication state
await context.fastmcp_context.set_state(
"access_token",
access_token,
serializable=False,
)
mcp_session_id = getattr(
context.fastmcp_context, "session_id", None
)
ensure_session_from_access_token(
access_token,
user_email,
mcp_session_id,
)
await context.fastmcp_context.set_state(
"auth_provider_type",
self.auth_provider_type,
)
await context.fastmcp_context.set_state(
"token_type", "google_oauth"
)
await context.fastmcp_context.set_state(
"user_email", user_email
)
await context.fastmcp_context.set_state(
"username", user_email
)
# Set the definitive authentication state
await context.fastmcp_context.set_state(
"authenticated_user_email", user_email
)
await context.fastmcp_context.set_state(
"authenticated_via", "bearer_token"
)
authenticated_user = user_email
auth_via = "bearer_token"
else:
logger.error(
"Failed to verify Google OAuth token"
)
except Exception as e:
logger.error(
f"Error verifying Google OAuth token: {e}"
)
else:
logger.warning(
"No auth provider available to verify Google token"
)
else:
# Non-Google JWT tokens require verification
# SECURITY: Never set authenticated_user_email from unverified tokens
logger.debug(
"Unverified JWT token rejected - only verified tokens accepted"
)
else:
logger.debug("No Bearer token in Authorization header")
else:
logger.debug(
"No HTTP headers available (might be using stdio transport)"
)
except Exception as e:
logger.debug(f"Could not get HTTP request: {e}")
# After trying HTTP headers, check for other authentication methods
# This consolidates all authentication logic in the middleware
if not authenticated_user:
logger.debug(
"No authentication found via bearer token, checking other methods"
)
# Check transport mode
from core.config import get_transport_mode
transport_mode = get_transport_mode()
if transport_mode == "stdio":
# In stdio mode, check if there's a session with credentials
# This is ONLY safe in stdio mode because it's single-user
logger.debug("Checking for stdio mode authentication")
# Get the requested user from the context if available
requested_user = None
if hasattr(context, "request") and hasattr(context.request, "params"):
requested_user = context.request.params.get("user_google_email")
elif hasattr(context, "arguments"):
# FastMCP may store arguments differently
requested_user = context.arguments.get("user_google_email")
if requested_user:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
# Check if user has a recent session
if store.has_session(requested_user):
logger.debug(
f"Using recent stdio session for {requested_user}"
)
# In stdio mode, we can trust the user has authenticated recently
await context.fastmcp_context.set_state(
"authenticated_user_email", requested_user
)
await context.fastmcp_context.set_state(
"authenticated_via", "stdio_session"
)
await context.fastmcp_context.set_state(
"auth_provider_type", "oauth21_stdio"
)
authenticated_user = requested_user
auth_via = "stdio_session"
except Exception as e:
logger.debug(f"Error checking stdio session: {e}")
# If no requested user was provided but exactly one session exists, assume it in stdio mode
if not authenticated_user:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
single_user = store.get_single_user_email()
if single_user:
logger.debug(
f"Defaulting to single stdio OAuth session for {single_user}"
)
await context.fastmcp_context.set_state(
"authenticated_user_email", single_user
)
await context.fastmcp_context.set_state(
"authenticated_via", "stdio_single_session"
)
await context.fastmcp_context.set_state(
"auth_provider_type", "oauth21_stdio"
)
await context.fastmcp_context.set_state(
"user_email", single_user
)
await context.fastmcp_context.set_state(
"username", single_user
)
authenticated_user = single_user
auth_via = "stdio_single_session"
except Exception as e:
logger.debug(
f"Error determining stdio single-user session: {e}"
)
# Check for MCP session binding
if not authenticated_user and hasattr(
context.fastmcp_context, "session_id"
):
mcp_session_id = context.fastmcp_context.session_id
if mcp_session_id:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
# Check if this MCP session is bound to a user
bound_user = store.get_user_by_mcp_session(mcp_session_id)
if bound_user:
logger.debug(f"MCP session bound to {bound_user}")
await context.fastmcp_context.set_state(
"authenticated_user_email", bound_user
)
await context.fastmcp_context.set_state(
"authenticated_via", "mcp_session_binding"
)
await context.fastmcp_context.set_state(
"auth_provider_type", "oauth21_session"
)
authenticated_user = bound_user
auth_via = "mcp_session_binding"
except Exception as e:
logger.debug(f"Error checking MCP session binding: {e}")
# Single exit point with logging
if authenticated_user:
logger.info(f"✓ Authenticated via {auth_via}: {authenticated_user}")
auth_email = await context.fastmcp_context.get_state(
"authenticated_user_email"
)
logger.debug(
f"Context state after auth: authenticated_user_email={auth_email}"
)
async def on_call_tool(self, context: MiddlewareContext, call_next):
"""Extract auth info from token and set in context state"""
logger.debug("Processing tool call authentication")
try:
await self._process_request_for_auth(context)
logger.debug("Passing to next handler")
result = await call_next(context)
logger.debug("Handler completed")
return result
except Exception as e:
# Check if this is an authentication error - don't log traceback for these
if "GoogleAuthenticationError" in str(
type(e)
) or "Access denied: Cannot retrieve credentials" in str(e):
logger.info(f"Authentication check failed: {e}")
else:
logger.error(f"Error in on_call_tool middleware: {e}", exc_info=True)
raise
async def on_get_prompt(self, context: MiddlewareContext, call_next):
"""Extract auth info for prompt requests too"""
logger.debug("Processing prompt authentication")
try:
await self._process_request_for_auth(context)
logger.debug("Passing prompt to next handler")
result = await call_next(context)
logger.debug("Prompt handler completed")
return result
except Exception as e:
# Check if this is an authentication error - don't log traceback for these
if "GoogleAuthenticationError" in str(
type(e)
) or "Access denied: Cannot retrieve credentials" in str(e):
logger.info(f"Authentication check failed in prompt: {e}")
else:
logger.error(f"Error in on_get_prompt middleware: {e}", exc_info=True)
raise

266
auth/credential_store.py Normal file
View File

@@ -0,0 +1,266 @@
"""
Credential Store API for Google Workspace MCP
This module provides a standardized interface for credential storage and retrieval,
supporting multiple backends configurable via environment variables.
"""
import os
import json
import logging
from abc import ABC, abstractmethod
from typing import Optional, List
from datetime import datetime
from google.oauth2.credentials import Credentials
logger = logging.getLogger(__name__)
class CredentialStore(ABC):
"""Abstract base class for credential storage."""
@abstractmethod
def get_credential(self, user_email: str) -> Optional[Credentials]:
"""
Get credentials for a user by email.
Args:
user_email: User's email address
Returns:
Google Credentials object or None if not found
"""
pass
@abstractmethod
def store_credential(self, user_email: str, credentials: Credentials) -> bool:
"""
Store credentials for a user.
Args:
user_email: User's email address
credentials: Google Credentials object to store
Returns:
True if successfully stored, False otherwise
"""
pass
@abstractmethod
def delete_credential(self, user_email: str) -> bool:
"""
Delete credentials for a user.
Args:
user_email: User's email address
Returns:
True if successfully deleted, False otherwise
"""
pass
@abstractmethod
def list_users(self) -> List[str]:
"""
List all users with stored credentials.
Returns:
List of user email addresses
"""
pass
class LocalDirectoryCredentialStore(CredentialStore):
"""Credential store that uses local JSON files for storage."""
def __init__(self, base_dir: Optional[str] = None):
"""
Initialize the local JSON credential store.
Args:
base_dir: Base directory for credential files. If None, uses the directory
configured by environment variables in this order:
1. WORKSPACE_MCP_CREDENTIALS_DIR (preferred)
2. GOOGLE_MCP_CREDENTIALS_DIR (backward compatibility)
3. ~/.google_workspace_mcp/credentials (default)
"""
if base_dir is None:
# Check WORKSPACE_MCP_CREDENTIALS_DIR first (preferred)
workspace_creds_dir = os.getenv("WORKSPACE_MCP_CREDENTIALS_DIR")
google_creds_dir = os.getenv("GOOGLE_MCP_CREDENTIALS_DIR")
if workspace_creds_dir:
base_dir = os.path.expanduser(workspace_creds_dir)
logger.info(
f"Using credentials directory from WORKSPACE_MCP_CREDENTIALS_DIR: {base_dir}"
)
# Fall back to GOOGLE_MCP_CREDENTIALS_DIR for backward compatibility
elif google_creds_dir:
base_dir = os.path.expanduser(google_creds_dir)
logger.info(
f"Using credentials directory from GOOGLE_MCP_CREDENTIALS_DIR: {base_dir}"
)
else:
home_dir = os.path.expanduser("~")
if home_dir and home_dir != "~":
base_dir = os.path.join(
home_dir, ".google_workspace_mcp", "credentials"
)
else:
base_dir = os.path.join(os.getcwd(), ".credentials")
logger.info(f"Using default credentials directory: {base_dir}")
self.base_dir = base_dir
logger.info(
f"LocalDirectoryCredentialStore initialized with base_dir: {base_dir}"
)
def _get_credential_path(self, user_email: str) -> str:
"""Get the file path for a user's credentials."""
if not os.path.exists(self.base_dir):
os.makedirs(self.base_dir)
logger.info(f"Created credentials directory: {self.base_dir}")
return os.path.join(self.base_dir, f"{user_email}.json")
def get_credential(self, user_email: str) -> Optional[Credentials]:
"""Get credentials from local JSON file."""
creds_path = self._get_credential_path(user_email)
if not os.path.exists(creds_path):
logger.debug(f"No credential file found for {user_email} at {creds_path}")
return None
try:
with open(creds_path, "r") as f:
creds_data = json.load(f)
# Parse expiry if present
expiry = None
if creds_data.get("expiry"):
try:
expiry = datetime.fromisoformat(creds_data["expiry"])
# Ensure timezone-naive datetime for Google auth library compatibility
if expiry.tzinfo is not None:
expiry = expiry.replace(tzinfo=None)
except (ValueError, TypeError) as e:
logger.warning(f"Could not parse expiry time for {user_email}: {e}")
credentials = Credentials(
token=creds_data.get("token"),
refresh_token=creds_data.get("refresh_token"),
token_uri=creds_data.get("token_uri"),
client_id=creds_data.get("client_id"),
client_secret=creds_data.get("client_secret"),
scopes=creds_data.get("scopes"),
expiry=expiry,
)
logger.debug(f"Loaded credentials for {user_email} from {creds_path}")
return credentials
except (IOError, json.JSONDecodeError, KeyError) as e:
logger.error(
f"Error loading credentials for {user_email} from {creds_path}: {e}"
)
return None
def store_credential(self, user_email: str, credentials: Credentials) -> bool:
"""Store credentials to local JSON file."""
creds_path = self._get_credential_path(user_email)
creds_data = {
"token": credentials.token,
"refresh_token": credentials.refresh_token,
"token_uri": credentials.token_uri,
"client_id": credentials.client_id,
"client_secret": credentials.client_secret,
"scopes": credentials.scopes,
"expiry": credentials.expiry.isoformat() if credentials.expiry else None,
}
try:
with open(creds_path, "w") as f:
json.dump(creds_data, f, indent=2)
logger.info(f"Stored credentials for {user_email} to {creds_path}")
return True
except IOError as e:
logger.error(
f"Error storing credentials for {user_email} to {creds_path}: {e}"
)
return False
def delete_credential(self, user_email: str) -> bool:
"""Delete credential file for a user."""
creds_path = self._get_credential_path(user_email)
try:
if os.path.exists(creds_path):
os.remove(creds_path)
logger.info(f"Deleted credentials for {user_email} from {creds_path}")
return True
else:
logger.debug(
f"No credential file to delete for {user_email} at {creds_path}"
)
return True # Consider it a success if file doesn't exist
except IOError as e:
logger.error(
f"Error deleting credentials for {user_email} from {creds_path}: {e}"
)
return False
def list_users(self) -> List[str]:
"""List all users with credential files."""
if not os.path.exists(self.base_dir):
return []
users = []
non_credential_files = {"oauth_states"}
try:
for filename in os.listdir(self.base_dir):
if filename.endswith(".json"):
user_email = filename[:-5] # Remove .json extension
if user_email in non_credential_files or "@" not in user_email:
continue
users.append(user_email)
logger.debug(
f"Found {len(users)} users with credentials in {self.base_dir}"
)
except OSError as e:
logger.error(f"Error listing credential files in {self.base_dir}: {e}")
return sorted(users)
# Global credential store instance
_credential_store: Optional[CredentialStore] = None
def get_credential_store() -> CredentialStore:
"""
Get the global credential store instance.
Returns:
Configured credential store instance
"""
global _credential_store
if _credential_store is None:
# always use LocalJsonCredentialStore as the default
# Future enhancement: support other backends via environment variables
_credential_store = LocalDirectoryCredentialStore()
logger.info(f"Initialized credential store: {type(_credential_store).__name__}")
return _credential_store
def set_credential_store(store: CredentialStore):
"""
Set the global credential store instance.
Args:
store: Credential store instance to use
"""
global _credential_store
_credential_store = store
logger.info(f"Set credential store: {type(store).__name__}")

View File

@@ -0,0 +1,188 @@
"""
External OAuth Provider for Google Workspace MCP
Extends FastMCP's GoogleProvider to support external OAuth flows where
access tokens (ya29.*) are issued by external systems and need validation.
This provider acts as a Resource Server only - it validates tokens issued by
Google's Authorization Server but does not issue tokens itself.
"""
import functools
import logging
import os
import time
from typing import Optional
from starlette.routing import Route
from fastmcp.server.auth.providers.google import GoogleProvider
from fastmcp.server.auth import AccessToken
from google.oauth2.credentials import Credentials
from auth.oauth_types import WorkspaceAccessToken
logger = logging.getLogger(__name__)
# Google's OAuth 2.0 Authorization Server
GOOGLE_ISSUER_URL = "https://accounts.google.com"
# Configurable session time in seconds (default: 1 hour, max: 24 hours)
_DEFAULT_SESSION_TIME = 3600
_MAX_SESSION_TIME = 86400
@functools.lru_cache(maxsize=1)
def get_session_time() -> int:
"""Parse SESSION_TIME from environment with fallback, min/max clamp.
Result is cached; changes require a server restart.
"""
raw = os.getenv("SESSION_TIME", "")
if not raw:
return _DEFAULT_SESSION_TIME
try:
value = int(raw)
except ValueError:
logger.warning(
"Invalid SESSION_TIME=%r, falling back to %d", raw, _DEFAULT_SESSION_TIME
)
return _DEFAULT_SESSION_TIME
clamped = max(1, min(value, _MAX_SESSION_TIME))
if clamped != value:
logger.warning(
"SESSION_TIME=%d clamped to %d (allowed range: 1%d)",
value,
clamped,
_MAX_SESSION_TIME,
)
return clamped
class ExternalOAuthProvider(GoogleProvider):
"""
Extended GoogleProvider that supports validating external Google OAuth access tokens.
This provider handles ya29.* access tokens by calling Google's userinfo API,
while maintaining compatibility with standard JWT ID tokens.
Unlike the standard GoogleProvider, this acts as a Resource Server only:
- Does NOT create /authorize, /token, /register endpoints
- Only advertises Google's authorization server in metadata
- Only validates tokens, does not issue them
"""
def __init__(
self,
client_id: str,
client_secret: str,
resource_server_url: Optional[str] = None,
**kwargs,
):
"""Initialize and store client credentials for token validation."""
self._resource_server_url = resource_server_url
super().__init__(client_id=client_id, client_secret=client_secret, **kwargs)
# Store credentials as they're not exposed by parent class
self._client_id = client_id
self._client_secret = client_secret
# Store as string - Pydantic validates it when passed to models
self.resource_server_url = self._resource_server_url
async def verify_token(self, token: str) -> Optional[AccessToken]:
"""
Verify a token - supports both JWT ID tokens and ya29.* access tokens.
For ya29.* access tokens (issued externally), validates by calling
Google's userinfo API. For JWT tokens, delegates to parent class.
Args:
token: Token string to verify (JWT or ya29.* access token)
Returns:
AccessToken object if valid, None otherwise
"""
# For ya29.* access tokens, validate using Google's userinfo API
if token.startswith("ya29."):
logger.debug("Validating external Google OAuth access token")
try:
from auth.google_auth import get_user_info
# Create minimal Credentials object for userinfo API call
credentials = Credentials(
token=token,
token_uri="https://oauth2.googleapis.com/token",
client_id=self._client_id,
client_secret=self._client_secret,
)
# Validate token by calling userinfo API
user_info = get_user_info(credentials, skip_valid_check=True)
if user_info and user_info.get("email"):
session_time = get_session_time()
# Token is valid - create AccessToken object
logger.info(
f"Validated external access token for: {user_info['email']}"
)
scope_list = list(getattr(self, "required_scopes", []) or [])
access_token = WorkspaceAccessToken(
token=token,
scopes=scope_list,
expires_at=int(time.time()) + session_time,
claims={
"email": user_info["email"],
"sub": user_info.get("id"),
},
client_id=self._client_id,
email=user_info["email"],
sub=user_info.get("id"),
)
return access_token
else:
logger.error("Could not get user info from access token")
return None
except Exception as e:
logger.error(f"Error validating external access token: {e}")
return None
# For JWT tokens, use parent class implementation
return await super().verify_token(token)
def get_routes(self, **kwargs) -> list[Route]:
"""
Get OAuth routes for external provider mode.
Returns only protected resource metadata routes that point to Google
as the authorization server. Does not create authorization server routes
(/authorize, /token, etc.) since tokens are issued by Google directly.
Args:
**kwargs: Additional arguments passed by FastMCP (e.g., mcp_path)
Returns:
List of routes - only protected resource metadata
"""
from mcp.server.auth.routes import create_protected_resource_routes
if not self.resource_server_url:
logger.warning(
"ExternalOAuthProvider: resource_server_url not set, no routes created"
)
return []
# Create protected resource routes that point to Google as the authorization server
# Pass strings directly - Pydantic validates them during model construction
protected_routes = create_protected_resource_routes(
resource_url=self.resource_server_url,
authorization_servers=[GOOGLE_ISSUER_URL],
scopes_supported=self.required_scopes,
resource_name="Google Workspace MCP",
resource_documentation=None,
)
logger.info(
f"ExternalOAuthProvider: Created protected resource routes pointing to {GOOGLE_ISSUER_URL}"
)
return protected_routes

1166
auth/google_auth.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,104 @@
"""
MCP Session Middleware
This middleware intercepts MCP requests and sets the session context
for use by tool functions.
"""
import logging
from typing import Callable, Any
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from auth.oauth21_session_store import (
SessionContext,
SessionContextManager,
extract_session_from_headers,
)
# OAuth 2.1 is now handled by FastMCP auth
logger = logging.getLogger(__name__)
class MCPSessionMiddleware(BaseHTTPMiddleware):
"""
Middleware that extracts session information from requests and makes it
available to MCP tool functions via context variables.
"""
async def dispatch(self, request: Request, call_next: Callable) -> Any:
"""Process request and set session context."""
logger.debug(
f"MCPSessionMiddleware processing request: {request.method} {request.url.path}"
)
# Skip non-MCP paths
if not request.url.path.startswith("/mcp"):
logger.debug(f"Skipping non-MCP path: {request.url.path}")
return await call_next(request)
session_context = None
try:
# Extract session information
headers = dict(request.headers)
session_id = extract_session_from_headers(headers)
# Try to get OAuth 2.1 auth context from FastMCP
auth_context = None
user_email = None
mcp_session_id = None
# Check for FastMCP auth context
if hasattr(request.state, "auth"):
auth_context = request.state.auth
# Extract user email from auth claims if available
if hasattr(auth_context, "claims") and auth_context.claims:
user_email = auth_context.claims.get("email")
# Check for FastMCP session ID (from streamable HTTP transport)
if hasattr(request.state, "session_id"):
mcp_session_id = request.state.session_id
logger.debug(f"Found FastMCP session ID: {mcp_session_id}")
# SECURITY: Do not decode JWT without verification
# User email must come from verified sources only (FastMCP auth context)
# Build session context
if session_id or auth_context or user_email or mcp_session_id:
# Create session ID hierarchy: explicit session_id > Google user session > FastMCP session
effective_session_id = session_id
if not effective_session_id and user_email:
effective_session_id = f"google_{user_email}"
elif not effective_session_id and mcp_session_id:
effective_session_id = mcp_session_id
session_context = SessionContext(
session_id=effective_session_id,
user_id=user_email
or (auth_context.user_id if auth_context else None),
auth_context=auth_context,
request=request,
metadata={
"path": request.url.path,
"method": request.method,
"user_email": user_email,
"mcp_session_id": mcp_session_id,
},
)
logger.debug(
f"MCP request with session: session_id={session_context.session_id}, "
f"user_id={session_context.user_id}, path={request.url.path}"
)
# Process request with session context
with SessionContextManager(session_context):
response = await call_next(request)
return response
except Exception as e:
logger.error(f"Error in MCP session middleware: {e}")
# Continue without session context
return await call_next(request)

View File

@@ -0,0 +1,989 @@
"""
OAuth 2.1 Session Store for Google Services
This module provides a global store for OAuth 2.1 authenticated sessions
that can be accessed by Google service decorators. It also includes
session context management and credential conversion functionality.
"""
import contextvars
import logging
from typing import Dict, Optional, Any, Tuple
from threading import RLock
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass
from fastmcp.server.auth import AccessToken
from google.oauth2.credentials import Credentials
from auth.oauth_config import is_external_oauth21_provider
logger = logging.getLogger(__name__)
def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
"""
Convert expiry values to timezone-naive UTC datetimes for google-auth compatibility.
Naive datetime inputs are assumed to already represent UTC and are returned unchanged so that
google-auth Credentials receive naive UTC datetimes for expiry comparison.
"""
if expiry is None:
return None
if isinstance(expiry, datetime):
if expiry.tzinfo is not None:
try:
return expiry.astimezone(timezone.utc).replace(tzinfo=None)
except Exception: # pragma: no cover - defensive
logger.debug(
"Failed to normalize aware expiry; returning without tzinfo"
)
return expiry.replace(tzinfo=None)
return expiry # Already naive; assumed to represent UTC
if isinstance(expiry, str):
try:
parsed = datetime.fromisoformat(expiry.replace("Z", "+00:00"))
except ValueError:
logger.debug("Failed to parse expiry string '%s'", expiry)
return None
return _normalize_expiry_to_naive_utc(parsed)
logger.debug("Unsupported expiry type '%s' (%s)", expiry, type(expiry))
return None
# Context variable to store the current session information
_current_session_context: contextvars.ContextVar[Optional["SessionContext"]] = (
contextvars.ContextVar("current_session_context", default=None)
)
@dataclass
class SessionContext:
"""Container for session-related information."""
session_id: Optional[str] = None
user_id: Optional[str] = None
auth_context: Optional[Any] = None
request: Optional[Any] = None
metadata: Dict[str, Any] = None
issuer: Optional[str] = None
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
def set_session_context(context: Optional[SessionContext]):
"""
Set the current session context.
Args:
context: The session context to set
"""
_current_session_context.set(context)
if context:
logger.debug(
f"Set session context: session_id={context.session_id}, user_id={context.user_id}"
)
else:
logger.debug("Cleared session context")
def get_session_context() -> Optional[SessionContext]:
"""
Get the current session context.
Returns:
The current session context or None
"""
return _current_session_context.get()
def clear_session_context():
"""Clear the current session context."""
set_session_context(None)
class SessionContextManager:
"""
Context manager for temporarily setting session context.
Usage:
with SessionContextManager(session_context):
# Code that needs access to session context
pass
"""
def __init__(self, context: Optional[SessionContext]):
self.context = context
self.token = None
def __enter__(self):
"""Set the session context."""
self.token = _current_session_context.set(self.context)
return self.context
def __exit__(self, exc_type, exc_val, exc_tb):
"""Reset the session context."""
if self.token:
_current_session_context.reset(self.token)
def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
"""
Extract session ID from request headers.
Args:
headers: Request headers
Returns:
Session ID if found
"""
# Try different header names
session_id = headers.get("mcp-session-id") or headers.get("Mcp-Session-Id")
if session_id:
return session_id
session_id = headers.get("x-session-id") or headers.get("X-Session-ID")
if session_id:
return session_id
# Try Authorization header for Bearer token
auth_header = headers.get("authorization") or headers.get("Authorization")
if auth_header and auth_header.lower().startswith("bearer "):
token = auth_header[7:] # Remove "Bearer " prefix
# Intentionally ignore empty tokens - "Bearer " with no token should not
# create a session context (avoids hash collisions on empty string)
if token:
# Use thread-safe lookup to find session by access token
store = get_oauth21_session_store()
session_id = store.find_session_id_for_access_token(token)
if session_id:
return session_id
# If no session found, create a temporary session ID from token hash
# This allows header-based authentication to work with session context
import hashlib
token_hash = hashlib.sha256(token.encode()).hexdigest()[:8]
return f"bearer_token_{token_hash}"
return None
# =============================================================================
# OAuth21SessionStore - Main Session Management
# =============================================================================
class OAuth21SessionStore:
"""
Global store for OAuth 2.1 authenticated sessions.
This store maintains a mapping of user emails to their OAuth 2.1
authenticated credentials, allowing Google services to access them.
It also maintains a mapping from FastMCP session IDs to user emails.
Security: Sessions are bound to specific users and can only access
their own credentials.
"""
def __init__(self):
self._sessions: Dict[str, Dict[str, Any]] = {}
self._mcp_session_mapping: Dict[
str, str
] = {} # Maps FastMCP session ID -> user email
self._session_auth_binding: Dict[
str, str
] = {} # Maps session ID -> authenticated user email (immutable)
self._oauth_states: Dict[str, Dict[str, Any]] = {}
self._lock = RLock()
def _cleanup_expired_oauth_states_locked(self):
"""Remove expired OAuth state entries. Caller must hold lock."""
now = datetime.now(timezone.utc)
expired_states = [
state
for state, data in self._oauth_states.items()
if data.get("expires_at") and data["expires_at"] <= now
]
for state in expired_states:
del self._oauth_states[state]
logger.debug(
"Removed expired OAuth state: %s",
state[:8] if len(state) > 8 else state,
)
def store_oauth_state(
self,
state: str,
session_id: Optional[str] = None,
expires_in_seconds: int = 600,
code_verifier: Optional[str] = None,
) -> None:
"""Persist an OAuth state value for later validation."""
if not state:
raise ValueError("OAuth state must be provided")
if expires_in_seconds < 0:
raise ValueError("expires_in_seconds must be non-negative")
with self._lock:
self._cleanup_expired_oauth_states_locked()
now = datetime.now(timezone.utc)
expiry = now + timedelta(seconds=expires_in_seconds)
self._oauth_states[state] = {
"session_id": session_id,
"expires_at": expiry,
"created_at": now,
"code_verifier": code_verifier,
}
logger.debug(
"Stored OAuth state %s (expires at %s)",
state[:8] if len(state) > 8 else state,
expiry.isoformat(),
)
def validate_and_consume_oauth_state(
self,
state: str,
session_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Validate that a state value exists and consume it.
Args:
state: The OAuth state returned by Google.
session_id: Optional session identifier that initiated the flow.
Returns:
Metadata associated with the state.
Raises:
ValueError: If the state is missing, expired, or does not match the session.
"""
if not state:
raise ValueError("Missing OAuth state parameter")
with self._lock:
self._cleanup_expired_oauth_states_locked()
state_info = self._oauth_states.get(state)
if not state_info:
logger.error(
"SECURITY: OAuth callback received unknown or expired state"
)
raise ValueError("Invalid or expired OAuth state parameter")
bound_session = state_info.get("session_id")
if bound_session and session_id and bound_session != session_id:
# Consume the state to prevent replay attempts
del self._oauth_states[state]
logger.error(
"SECURITY: OAuth state session mismatch (expected %s, got %s)",
bound_session,
session_id,
)
raise ValueError("OAuth state does not match the initiating session")
# State is valid consume it to prevent reuse
del self._oauth_states[state]
logger.debug(
"Validated OAuth state %s",
state[:8] if len(state) > 8 else state,
)
return state_info
def store_session(
self,
user_email: str,
access_token: str,
refresh_token: Optional[str] = None,
token_uri: str = "https://oauth2.googleapis.com/token",
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
scopes: Optional[list] = None,
expiry: Optional[Any] = None,
session_id: Optional[str] = None,
mcp_session_id: Optional[str] = None,
issuer: Optional[str] = None,
):
"""
Store OAuth 2.1 session information.
Args:
user_email: User's email address
access_token: OAuth 2.1 access token
refresh_token: OAuth 2.1 refresh token
token_uri: Token endpoint URI
client_id: OAuth client ID
client_secret: OAuth client secret
scopes: List of granted scopes
expiry: Token expiry time
session_id: OAuth 2.1 session ID
mcp_session_id: FastMCP session ID to map to this user
issuer: Token issuer (e.g., "https://accounts.google.com")
"""
with self._lock:
normalized_expiry = _normalize_expiry_to_naive_utc(expiry)
# Clean up previous session mappings for this user before storing new one
old_session = self._sessions.get(user_email)
if old_session:
old_mcp_session_id = old_session.get("mcp_session_id")
old_session_id = old_session.get("session_id")
# Remove old MCP session mapping if it differs from new one
if old_mcp_session_id and old_mcp_session_id != mcp_session_id:
if old_mcp_session_id in self._mcp_session_mapping:
del self._mcp_session_mapping[old_mcp_session_id]
logger.debug(
f"Removed stale MCP session mapping: {old_mcp_session_id}"
)
if old_mcp_session_id in self._session_auth_binding:
del self._session_auth_binding[old_mcp_session_id]
logger.debug(
f"Removed stale auth binding: {old_mcp_session_id}"
)
# Remove old OAuth session binding if it differs from new one
if old_session_id and old_session_id != session_id:
if old_session_id in self._session_auth_binding:
del self._session_auth_binding[old_session_id]
logger.debug(
f"Removed stale OAuth session binding: {old_session_id}"
)
session_info = {
"access_token": access_token,
"refresh_token": refresh_token,
"token_uri": token_uri,
"client_id": client_id,
"client_secret": client_secret,
"scopes": scopes or [],
"expiry": normalized_expiry,
"session_id": session_id,
"mcp_session_id": mcp_session_id,
"issuer": issuer,
}
self._sessions[user_email] = session_info
# Store MCP session mapping if provided
if mcp_session_id:
# Create immutable session binding (first binding wins, cannot be changed)
if mcp_session_id not in self._session_auth_binding:
self._session_auth_binding[mcp_session_id] = user_email
logger.info(
f"Created immutable session binding: {mcp_session_id} -> {user_email}"
)
elif self._session_auth_binding[mcp_session_id] != user_email:
# Security: Attempt to bind session to different user
logger.error(
f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}"
)
raise ValueError(
f"Session {mcp_session_id} is already bound to a different user"
)
self._mcp_session_mapping[mcp_session_id] = user_email
logger.info(
f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})"
)
else:
logger.info(
f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})"
)
# Also create binding for the OAuth session ID
if session_id and session_id not in self._session_auth_binding:
self._session_auth_binding[session_id] = user_email
def get_credentials(self, user_email: str) -> Optional[Credentials]:
"""
Get Google credentials for a user from OAuth 2.1 session.
Args:
user_email: User's email address
Returns:
Google Credentials object or None
"""
with self._lock:
session_info = self._sessions.get(user_email)
if not session_info:
logger.debug(f"No OAuth 2.1 session found for {user_email}")
return None
try:
# Create Google credentials from session info
credentials = Credentials(
token=session_info["access_token"],
refresh_token=session_info.get("refresh_token"),
token_uri=session_info["token_uri"],
client_id=session_info.get("client_id"),
client_secret=session_info.get("client_secret"),
scopes=session_info.get("scopes", []),
expiry=session_info.get("expiry"),
)
logger.debug(f"Retrieved OAuth 2.1 credentials for {user_email}")
return credentials
except Exception as e:
logger.error(f"Failed to create credentials for {user_email}: {e}")
return None
def get_credentials_by_mcp_session(
self, mcp_session_id: str
) -> Optional[Credentials]:
"""
Get Google credentials using FastMCP session ID.
Args:
mcp_session_id: FastMCP session ID
Returns:
Google Credentials object or None
"""
with self._lock:
# Look up user email from MCP session mapping
user_email = self._mcp_session_mapping.get(mcp_session_id)
if not user_email:
logger.debug(f"No user mapping found for MCP session {mcp_session_id}")
return None
logger.debug(f"Found user {user_email} for MCP session {mcp_session_id}")
return self.get_credentials(user_email)
def get_credentials_with_validation(
self,
requested_user_email: str,
session_id: Optional[str] = None,
auth_token_email: Optional[str] = None,
allow_recent_auth: bool = False,
) -> Optional[Credentials]:
"""
Get Google credentials with session validation.
This method ensures that a session can only access credentials for its
authenticated user, preventing cross-account access.
Args:
requested_user_email: The email of the user whose credentials are requested
session_id: The current session ID (MCP or OAuth session)
auth_token_email: Email from the verified auth token (if available)
Returns:
Google Credentials object if validation passes, None otherwise
"""
with self._lock:
# Priority 1: Check auth token email (most secure, from verified JWT)
if auth_token_email:
if auth_token_email != requested_user_email:
logger.error(
f"SECURITY VIOLATION: Token for {auth_token_email} attempted to access "
f"credentials for {requested_user_email}"
)
return None
# Token email matches, allow access
return self.get_credentials(requested_user_email)
# Priority 2: Check session binding
if session_id:
bound_user = self._session_auth_binding.get(session_id)
if bound_user:
if bound_user != requested_user_email:
logger.error(
f"SECURITY VIOLATION: Session {session_id} (bound to {bound_user}) "
f"attempted to access credentials for {requested_user_email}"
)
return None
# Session binding matches, allow access
return self.get_credentials(requested_user_email)
# Check if this is an MCP session
mcp_user = self._mcp_session_mapping.get(session_id)
if mcp_user:
if mcp_user != requested_user_email:
logger.error(
f"SECURITY VIOLATION: MCP session {session_id} (user {mcp_user}) "
f"attempted to access credentials for {requested_user_email}"
)
return None
# MCP session matches, allow access
return self.get_credentials(requested_user_email)
# Special case: Allow access if user has recently authenticated (for clients that don't send tokens)
# CRITICAL SECURITY: This is ONLY allowed in stdio mode, NEVER in OAuth 2.1 mode
if allow_recent_auth and requested_user_email in self._sessions:
# Check transport mode to ensure this is only used in stdio
try:
from core.config import get_transport_mode
transport_mode = get_transport_mode()
if transport_mode != "stdio":
logger.error(
f"SECURITY: Attempted to use allow_recent_auth in {transport_mode} mode. "
f"This is only allowed in stdio mode!"
)
return None
except Exception as e:
logger.error(f"Failed to check transport mode: {e}")
return None
logger.info(
f"Allowing credential access for {requested_user_email} based on recent authentication "
f"(stdio mode only - client not sending bearer token)"
)
return self.get_credentials(requested_user_email)
# No session or token info available - deny access for security
logger.warning(
f"Credential access denied for {requested_user_email}: No valid session or token"
)
return None
def get_user_by_mcp_session(self, mcp_session_id: str) -> Optional[str]:
"""
Get user email by FastMCP session ID.
Args:
mcp_session_id: FastMCP session ID
Returns:
User email or None
"""
with self._lock:
return self._mcp_session_mapping.get(mcp_session_id)
def get_session_info(self, user_email: str) -> Optional[Dict[str, Any]]:
"""
Get complete session information including issuer.
Args:
user_email: User's email address
Returns:
Session information dictionary or None
"""
with self._lock:
return self._sessions.get(user_email)
def remove_session(self, user_email: str):
"""Remove session for a user."""
with self._lock:
if user_email in self._sessions:
# Get session IDs to clean up mappings
session_info = self._sessions.get(user_email, {})
mcp_session_id = session_info.get("mcp_session_id")
session_id = session_info.get("session_id")
# Remove from sessions
del self._sessions[user_email]
# Remove from MCP mapping if exists
if mcp_session_id and mcp_session_id in self._mcp_session_mapping:
del self._mcp_session_mapping[mcp_session_id]
# Also remove from auth binding
if mcp_session_id in self._session_auth_binding:
del self._session_auth_binding[mcp_session_id]
logger.info(
f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}"
)
# Remove OAuth session binding if exists
if session_id and session_id in self._session_auth_binding:
del self._session_auth_binding[session_id]
if not mcp_session_id:
logger.info(f"Removed OAuth 2.1 session for {user_email}")
# Clean up any orphaned mappings that may have accumulated
self._cleanup_orphaned_mappings_locked()
def has_session(self, user_email: str) -> bool:
"""Check if a user has an active session."""
with self._lock:
return user_email in self._sessions
def has_mcp_session(self, mcp_session_id: str) -> bool:
"""Check if an MCP session has an associated user session."""
with self._lock:
return mcp_session_id in self._mcp_session_mapping
def get_single_user_email(self) -> Optional[str]:
"""Return the sole authenticated user email when exactly one session exists."""
with self._lock:
if len(self._sessions) == 1:
return next(iter(self._sessions))
return None
def get_stats(self) -> Dict[str, Any]:
"""Get store statistics."""
with self._lock:
return {
"total_sessions": len(self._sessions),
"users": list(self._sessions.keys()),
"mcp_session_mappings": len(self._mcp_session_mapping),
"mcp_sessions": list(self._mcp_session_mapping.keys()),
}
def find_session_id_for_access_token(self, token: str) -> Optional[str]:
"""
Thread-safe lookup of session ID by access token.
Args:
token: The access token to search for
Returns:
Session ID if found, None otherwise
"""
with self._lock:
for user_email, session_info in self._sessions.items():
if session_info.get("access_token") == token:
return session_info.get("session_id") or f"bearer_{user_email}"
return None
def _cleanup_orphaned_mappings_locked(self) -> int:
"""Remove orphaned mappings. Caller must hold lock."""
# Collect valid session IDs and mcp_session_ids from active sessions
valid_session_ids = set()
valid_mcp_session_ids = set()
for session_info in self._sessions.values():
if session_info.get("session_id"):
valid_session_ids.add(session_info["session_id"])
if session_info.get("mcp_session_id"):
valid_mcp_session_ids.add(session_info["mcp_session_id"])
removed = 0
# Remove orphaned MCP session mappings
orphaned_mcp = [
sid for sid in self._mcp_session_mapping if sid not in valid_mcp_session_ids
]
for sid in orphaned_mcp:
del self._mcp_session_mapping[sid]
removed += 1
logger.debug(f"Removed orphaned MCP session mapping: {sid}")
# Remove orphaned auth bindings
valid_bindings = valid_session_ids | valid_mcp_session_ids
orphaned_bindings = [
sid for sid in self._session_auth_binding if sid not in valid_bindings
]
for sid in orphaned_bindings:
del self._session_auth_binding[sid]
removed += 1
logger.debug(f"Removed orphaned auth binding: {sid}")
if removed > 0:
logger.info(f"Cleaned up {removed} orphaned session mappings/bindings")
return removed
def cleanup_orphaned_mappings(self) -> int:
"""
Remove orphaned entries from mcp_session_mapping and session_auth_binding.
Returns:
Number of orphaned entries removed
"""
with self._lock:
return self._cleanup_orphaned_mappings_locked()
# Global instance
_global_store = OAuth21SessionStore()
def get_oauth21_session_store() -> OAuth21SessionStore:
"""Get the global OAuth 2.1 session store."""
return _global_store
# =============================================================================
# Google Credentials Bridge (absorbed from oauth21_google_bridge.py)
# =============================================================================
# Global auth provider instance (set during server initialization)
_auth_provider = None
def set_auth_provider(provider):
"""Set the global auth provider instance."""
global _auth_provider
_auth_provider = provider
logger.debug("OAuth 2.1 session store configured")
def get_auth_provider():
"""Get the global auth provider instance."""
return _auth_provider
def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
"""Resolve OAuth client credentials from the active provider or configuration."""
client_id: Optional[str] = None
client_secret: Optional[str] = None
if _auth_provider:
client_id = getattr(_auth_provider, "_upstream_client_id", None)
secret_obj = getattr(_auth_provider, "_upstream_client_secret", None)
if secret_obj is not None:
if hasattr(secret_obj, "get_secret_value"):
try:
client_secret = secret_obj.get_secret_value() # type: ignore[call-arg]
except Exception as exc: # pragma: no cover - defensive
logger.debug(
f"Failed to resolve client secret from provider: {exc}"
)
elif isinstance(secret_obj, str):
client_secret = secret_obj
if not client_id or not client_secret:
try:
from auth.oauth_config import get_oauth_config
cfg = get_oauth_config()
client_id = client_id or cfg.client_id
client_secret = client_secret or cfg.client_secret
except Exception as exc: # pragma: no cover - defensive
logger.debug(f"Failed to resolve client credentials from config: {exc}")
return client_id, client_secret
def _build_credentials_from_provider(
access_token: AccessToken,
) -> Optional[Credentials]:
"""Construct Google credentials from the provider cache."""
if not _auth_provider:
return None
access_entry = getattr(_auth_provider, "_access_tokens", {}).get(access_token.token)
if not access_entry:
access_entry = access_token
client_id, client_secret = _resolve_client_credentials()
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(
access_token.token
)
refresh_token_obj = None
if refresh_token_value:
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(
refresh_token_value
)
expiry = None
expires_at = getattr(access_entry, "expires_at", None)
if expires_at:
try:
expiry_candidate = datetime.fromtimestamp(expires_at, tz=timezone.utc)
expiry = _normalize_expiry_to_naive_utc(expiry_candidate)
except Exception: # pragma: no cover - defensive
expiry = None
scopes = getattr(access_entry, "scopes", None)
return Credentials(
token=access_token.token,
refresh_token=refresh_token_obj.token if refresh_token_obj else None,
token_uri="https://oauth2.googleapis.com/token",
client_id=client_id,
client_secret=client_secret,
scopes=scopes,
expiry=expiry,
)
def ensure_session_from_access_token(
access_token: AccessToken,
user_email: Optional[str],
mcp_session_id: Optional[str] = None,
) -> Optional[Credentials]:
"""Ensure credentials derived from an access token are cached and returned."""
if not access_token:
return None
email = user_email
if not email and getattr(access_token, "claims", None):
email = access_token.claims.get("email")
credentials = _build_credentials_from_provider(access_token)
store_expiry: Optional[datetime] = None
if credentials is None:
client_id, client_secret = _resolve_client_credentials()
expiry = None
expires_at = getattr(access_token, "expires_at", None)
if expires_at:
try:
expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
except Exception: # pragma: no cover - defensive
expiry = None
normalized_expiry = _normalize_expiry_to_naive_utc(expiry)
credentials = Credentials(
token=access_token.token,
refresh_token=None,
token_uri="https://oauth2.googleapis.com/token",
client_id=client_id,
client_secret=client_secret,
scopes=getattr(access_token, "scopes", None),
expiry=normalized_expiry,
)
store_expiry = expiry
else:
store_expiry = credentials.expiry
# Skip session storage for external OAuth 2.1 to prevent memory leak from ephemeral tokens
if email and not is_external_oauth21_provider():
try:
store = get_oauth21_session_store()
store.store_session(
user_email=email,
access_token=credentials.token,
refresh_token=credentials.refresh_token,
token_uri=credentials.token_uri,
client_id=credentials.client_id,
client_secret=credentials.client_secret,
scopes=credentials.scopes,
expiry=store_expiry,
session_id=f"google_{email}",
mcp_session_id=mcp_session_id,
issuer="https://accounts.google.com",
)
except Exception as exc: # pragma: no cover - defensive
logger.debug(f"Failed to cache credentials for {email}: {exc}")
return credentials
def get_credentials_from_token(
access_token: str, user_email: Optional[str] = None
) -> Optional[Credentials]:
"""
Convert a bearer token to Google credentials.
Args:
access_token: The bearer token
user_email: Optional user email for session lookup
Returns:
Google Credentials object or None
"""
try:
store = get_oauth21_session_store()
# If we have user_email, try to get credentials from store
if user_email:
credentials = store.get_credentials(user_email)
if credentials and credentials.token == access_token:
logger.debug(f"Found matching credentials from store for {user_email}")
return credentials
# If the FastMCP provider is managing tokens, sync from provider storage
if _auth_provider:
access_record = getattr(_auth_provider, "_access_tokens", {}).get(
access_token
)
if access_record:
logger.debug("Building credentials from FastMCP provider cache")
return ensure_session_from_access_token(access_record, user_email)
# Otherwise, create minimal credentials with just the access token
# Assume token is valid for 1 hour (typical for Google tokens)
expiry = _normalize_expiry_to_naive_utc(
datetime.now(timezone.utc) + timedelta(hours=1)
)
client_id, client_secret = _resolve_client_credentials()
credentials = Credentials(
token=access_token,
refresh_token=None,
token_uri="https://oauth2.googleapis.com/token",
client_id=client_id,
client_secret=client_secret,
scopes=None,
expiry=expiry,
)
logger.debug("Created fallback Google credentials from bearer token")
return credentials
except Exception as e:
logger.error(f"Failed to create Google credentials from token: {e}")
return None
def store_token_session(
token_response: dict, user_email: str, mcp_session_id: Optional[str] = None
) -> str:
"""
Store a token response in the session store.
Args:
token_response: OAuth token response from Google
user_email: User's email address
mcp_session_id: Optional FastMCP session ID to map to this user
Returns:
Session ID
"""
if not _auth_provider:
logger.error("Auth provider not configured")
return ""
try:
# Try to get FastMCP session ID from context if not provided
if not mcp_session_id:
try:
from core.context import get_fastmcp_session_id
mcp_session_id = get_fastmcp_session_id()
if mcp_session_id:
logger.debug(
f"Got FastMCP session ID from context: {mcp_session_id}"
)
except Exception as e:
logger.debug(f"Could not get FastMCP session from context: {e}")
# Store session in OAuth21SessionStore
store = get_oauth21_session_store()
session_id = f"google_{user_email}"
client_id, client_secret = _resolve_client_credentials()
scopes = token_response.get("scope", "")
scopes_list = scopes.split() if scopes else None
expiry = datetime.now(timezone.utc) + timedelta(
seconds=token_response.get("expires_in", 3600)
)
store.store_session(
user_email=user_email,
access_token=token_response.get("access_token"),
refresh_token=token_response.get("refresh_token"),
token_uri="https://oauth2.googleapis.com/token",
client_id=client_id,
client_secret=client_secret,
scopes=scopes_list,
expiry=expiry,
session_id=session_id,
mcp_session_id=mcp_session_id,
issuer="https://accounts.google.com",
)
if mcp_session_id:
logger.info(
f"Stored token session for {user_email} with MCP session {mcp_session_id}"
)
else:
logger.info(f"Stored token session for {user_email}")
return session_id
except Exception as e:
logger.error(f"Failed to store token session: {e}")
return ""

View File

@@ -0,0 +1,287 @@
"""
Transport-aware OAuth callback handling.
In streamable-http mode: Uses the existing FastAPI server
In stdio mode: Starts a minimal HTTP server just for OAuth callbacks
"""
import asyncio
import logging
import threading
import time
import socket
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse, JSONResponse
from typing import Optional
from urllib.parse import urlparse
from auth.scopes import SCOPES, get_current_scopes # noqa
from auth.oauth_responses import (
create_error_response,
create_success_response,
create_server_error_response,
)
from auth.google_auth import handle_auth_callback, check_client_secrets
from auth.oauth_config import get_oauth_redirect_uri
logger = logging.getLogger(__name__)
class MinimalOAuthServer:
"""
Minimal HTTP server for OAuth callbacks in stdio mode.
Only starts when needed and uses the same port (8000) as streamable-http mode.
"""
def __init__(self, port: int = 8000, base_uri: str = "http://localhost"):
self.port = port
self.base_uri = base_uri
self.app = FastAPI()
self.server = None
self.server_thread = None
self.is_running = False
# Setup the callback route
self._setup_callback_route()
# Setup attachment serving route
self._setup_attachment_route()
def _setup_callback_route(self):
"""Setup the OAuth callback route."""
@self.app.get("/oauth2callback")
async def oauth_callback(request: Request):
"""Handle OAuth callback - same logic as in core/server.py"""
code = request.query_params.get("code")
error = request.query_params.get("error")
if error:
error_message = (
f"Authentication failed: Google returned an error: {error}."
)
logger.error(error_message)
return create_error_response(error_message)
if not code:
error_message = (
"Authentication failed: No authorization code received from Google."
)
logger.error(error_message)
return create_error_response(error_message)
try:
# Check if we have credentials available (environment variables or file)
error_message = check_client_secrets()
if error_message:
return create_server_error_response(error_message)
logger.info(
"OAuth callback: Received authorization code. Attempting to exchange for tokens."
)
# Session ID tracking removed - not needed
# Exchange code for credentials
redirect_uri = get_oauth_redirect_uri()
verified_user_id, credentials = handle_auth_callback(
scopes=get_current_scopes(),
authorization_response=str(request.url),
redirect_uri=redirect_uri,
session_id=None,
)
logger.info(
f"OAuth callback: Successfully authenticated user: {verified_user_id}."
)
# Return success page using shared template
return create_success_response(verified_user_id)
except Exception as e:
error_message_detail = f"Error processing OAuth callback: {str(e)}"
logger.error(error_message_detail, exc_info=True)
return create_server_error_response(str(e))
def _setup_attachment_route(self):
"""Setup the attachment serving route."""
from core.attachment_storage import get_attachment_storage
@self.app.get("/attachments/{file_id}")
async def serve_attachment(file_id: str, request: Request):
"""Serve a stored attachment file."""
storage = get_attachment_storage()
metadata = storage.get_attachment_metadata(file_id)
if not metadata:
return JSONResponse(
{"error": "Attachment not found or expired"}, status_code=404
)
file_path = storage.get_attachment_path(file_id)
if not file_path:
return JSONResponse(
{"error": "Attachment file not found"}, status_code=404
)
return FileResponse(
path=str(file_path),
filename=metadata["filename"],
media_type=metadata["mime_type"],
)
def start(self) -> tuple[bool, str]:
"""
Start the minimal OAuth server.
Returns:
Tuple of (success: bool, error_message: str)
"""
if self.is_running:
logger.info("Minimal OAuth server is already running")
return True, ""
# Check if port is available
# Extract hostname from base_uri (e.g., "http://localhost" -> "localhost")
try:
parsed_uri = urlparse(self.base_uri)
hostname = parsed_uri.hostname or "localhost"
except Exception:
hostname = "localhost"
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((hostname, self.port))
except OSError:
error_msg = f"Port {self.port} is already in use on {hostname}. Cannot start minimal OAuth server."
logger.error(error_msg)
return False, error_msg
def run_server():
"""Run the server in a separate thread."""
try:
config = uvicorn.Config(
self.app,
host=hostname,
port=self.port,
log_level="warning",
access_log=False,
)
self.server = uvicorn.Server(config)
asyncio.run(self.server.serve())
except Exception as e:
logger.error(f"Minimal OAuth server error: {e}", exc_info=True)
self.is_running = False
# Start server in background thread
self.server_thread = threading.Thread(target=run_server, daemon=True)
self.server_thread.start()
# Wait for server to start
max_wait = 3.0
start_time = time.time()
while time.time() - start_time < max_wait:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
result = s.connect_ex((hostname, self.port))
if result == 0:
self.is_running = True
logger.info(
f"Minimal OAuth server started on {hostname}:{self.port}"
)
return True, ""
except Exception:
pass
time.sleep(0.1)
error_msg = f"Failed to start minimal OAuth server on {hostname}:{self.port} - server did not respond within {max_wait}s"
logger.error(error_msg)
return False, error_msg
def stop(self):
"""Stop the minimal OAuth server."""
if not self.is_running:
return
try:
if self.server:
if hasattr(self.server, "should_exit"):
self.server.should_exit = True
if self.server_thread and self.server_thread.is_alive():
self.server_thread.join(timeout=3.0)
self.is_running = False
logger.info("Minimal OAuth server stopped")
except Exception as e:
logger.error(f"Error stopping minimal OAuth server: {e}", exc_info=True)
# Global instance for stdio mode
_minimal_oauth_server: Optional[MinimalOAuthServer] = None
def ensure_oauth_callback_available(
transport_mode: str = "stdio", port: int = 8000, base_uri: str = "http://localhost"
) -> tuple[bool, str]:
"""
Ensure OAuth callback endpoint is available for the given transport mode.
For streamable-http: Assumes the main server is already running
For stdio: Starts a minimal server if needed
Args:
transport_mode: "stdio" or "streamable-http"
port: Port number (default 8000)
base_uri: Base URI (default "http://localhost")
Returns:
Tuple of (success: bool, error_message: str)
"""
global _minimal_oauth_server
if transport_mode == "streamable-http":
# In streamable-http mode, the main FastAPI server should handle callbacks
logger.debug(
"Using existing FastAPI server for OAuth callbacks (streamable-http mode)"
)
return True, ""
elif transport_mode == "stdio":
# In stdio mode, start minimal server if not already running
if _minimal_oauth_server is None:
logger.info(f"Creating minimal OAuth server instance for {base_uri}:{port}")
_minimal_oauth_server = MinimalOAuthServer(port, base_uri)
if not _minimal_oauth_server.is_running:
logger.info("Starting minimal OAuth server for stdio mode")
success, error_msg = _minimal_oauth_server.start()
if success:
logger.info(
f"Minimal OAuth server successfully started on {base_uri}:{port}"
)
return True, ""
else:
logger.error(
f"Failed to start minimal OAuth server on {base_uri}:{port}: {error_msg}"
)
return False, error_msg
else:
logger.info("Minimal OAuth server is already running")
return True, ""
else:
error_msg = f"Unknown transport mode: {transport_mode}"
logger.error(error_msg)
return False, error_msg
def cleanup_oauth_callback_server():
"""Clean up the minimal OAuth server if it was started."""
global _minimal_oauth_server
if _minimal_oauth_server:
_minimal_oauth_server.stop()
_minimal_oauth_server = None

444
auth/oauth_config.py Normal file
View File

@@ -0,0 +1,444 @@
"""
OAuth Configuration Management
This module centralizes OAuth-related configuration to eliminate hardcoded values
scattered throughout the codebase. It provides environment variable support and
sensible defaults for all OAuth-related settings.
Supports both OAuth 2.0 and OAuth 2.1 with automatic client capability detection.
"""
import os
from threading import RLock
from urllib.parse import urlparse
from typing import List, Optional, Dict, Any
class OAuthConfig:
"""
Centralized OAuth configuration management.
This class eliminates the hardcoded configuration anti-pattern identified
in the challenge review by providing a single source of truth for all
OAuth-related configuration values.
"""
def __init__(self):
# Base server configuration
self.base_uri = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
self.port = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", "8000")))
self.base_url = f"{self.base_uri}:{self.port}"
# External URL for reverse proxy scenarios
self.external_url = os.getenv("WORKSPACE_EXTERNAL_URL")
# OAuth client configuration
self.client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
self.client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
# OAuth 2.1 configuration
self.oauth21_enabled = (
os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true"
)
self.pkce_required = self.oauth21_enabled # PKCE is mandatory in OAuth 2.1
self.supported_code_challenge_methods = (
["S256", "plain"] if not self.oauth21_enabled else ["S256"]
)
# External OAuth 2.1 provider configuration
self.external_oauth21_provider = (
os.getenv("EXTERNAL_OAUTH21_PROVIDER", "false").lower() == "true"
)
if self.external_oauth21_provider and not self.oauth21_enabled:
raise ValueError(
"EXTERNAL_OAUTH21_PROVIDER requires MCP_ENABLE_OAUTH21=true"
)
# Stateless mode configuration
self.stateless_mode = (
os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
)
if self.stateless_mode and not self.oauth21_enabled:
raise ValueError(
"WORKSPACE_MCP_STATELESS_MODE requires MCP_ENABLE_OAUTH21=true"
)
# Transport mode (will be set at runtime)
self._transport_mode = "stdio" # Default
# Redirect URI configuration
self.redirect_uri = self._get_redirect_uri()
self.redirect_path = self._get_redirect_path(self.redirect_uri)
# Ensure FastMCP's Google provider picks up our existing configuration
self._apply_fastmcp_google_env()
def _get_redirect_uri(self) -> str:
"""
Get the OAuth redirect URI, supporting reverse proxy configurations.
Returns:
The configured redirect URI
"""
explicit_uri = os.getenv("GOOGLE_OAUTH_REDIRECT_URI")
if explicit_uri:
return explicit_uri
return f"{self.base_url}/oauth2callback"
@staticmethod
def _get_redirect_path(uri: str) -> str:
"""Extract the redirect path from a full redirect URI."""
parsed = urlparse(uri)
if parsed.scheme or parsed.netloc:
path = parsed.path or "/oauth2callback"
else:
# If the value was already a path, ensure it starts with '/'
path = uri if uri.startswith("/") else f"/{uri}"
return path or "/oauth2callback"
def _apply_fastmcp_google_env(self) -> None:
"""Mirror legacy GOOGLE_* env vars into FastMCP Google provider settings."""
if not self.client_id:
return
def _set_if_absent(key: str, value: Optional[str]) -> None:
if value and key not in os.environ:
os.environ[key] = value
# Don't set FASTMCP_SERVER_AUTH if using external OAuth provider
# (external OAuth means protocol-level auth is disabled, only tool-level auth)
if not self.external_oauth21_provider:
_set_if_absent(
"FASTMCP_SERVER_AUTH",
"fastmcp.server.auth.providers.google.GoogleProvider"
if self.oauth21_enabled
else None,
)
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_ID", self.client_id)
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_SECRET", self.client_secret)
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_BASE_URL", self.get_oauth_base_url())
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_REDIRECT_PATH", self.redirect_path)
def get_redirect_uris(self) -> List[str]:
"""
Get all valid OAuth redirect URIs.
Returns:
List of all supported redirect URIs
"""
uris = []
# Primary redirect URI
uris.append(self.redirect_uri)
# Custom redirect URIs from environment
custom_uris = os.getenv("OAUTH_CUSTOM_REDIRECT_URIS")
if custom_uris:
uris.extend([uri.strip() for uri in custom_uris.split(",")])
# Remove duplicates while preserving order
return list(dict.fromkeys(uris))
def get_allowed_origins(self) -> List[str]:
"""
Get allowed CORS origins for OAuth endpoints.
Returns:
List of allowed origins for CORS
"""
origins = []
# Server's own origin
origins.append(self.base_url)
# VS Code and development origins
origins.extend(
[
"vscode-webview://",
"https://vscode.dev",
"https://github.dev",
]
)
# Custom origins from environment
custom_origins = os.getenv("OAUTH_ALLOWED_ORIGINS")
if custom_origins:
origins.extend([origin.strip() for origin in custom_origins.split(",")])
return list(dict.fromkeys(origins))
def is_configured(self) -> bool:
"""
Check if OAuth is properly configured.
Returns:
True if OAuth client credentials are available
"""
return bool(self.client_id and self.client_secret)
def get_oauth_base_url(self) -> str:
"""
Get OAuth base URL for constructing OAuth endpoints.
Uses WORKSPACE_EXTERNAL_URL if set (for reverse proxy scenarios),
otherwise falls back to constructed base_url with port.
Returns:
Base URL for OAuth endpoints
"""
if self.external_url:
return self.external_url
return self.base_url
def validate_redirect_uri(self, uri: str) -> bool:
"""
Validate if a redirect URI is allowed.
Args:
uri: The redirect URI to validate
Returns:
True if the URI is allowed, False otherwise
"""
allowed_uris = self.get_redirect_uris()
return uri in allowed_uris
def get_environment_summary(self) -> dict:
"""
Get a summary of the current OAuth configuration.
Returns:
Dictionary with configuration summary (excluding secrets)
"""
return {
"base_url": self.base_url,
"external_url": self.external_url,
"effective_oauth_url": self.get_oauth_base_url(),
"redirect_uri": self.redirect_uri,
"redirect_path": self.redirect_path,
"client_configured": bool(self.client_id),
"oauth21_enabled": self.oauth21_enabled,
"external_oauth21_provider": self.external_oauth21_provider,
"pkce_required": self.pkce_required,
"transport_mode": self._transport_mode,
"total_redirect_uris": len(self.get_redirect_uris()),
"total_allowed_origins": len(self.get_allowed_origins()),
}
def set_transport_mode(self, mode: str) -> None:
"""
Set the current transport mode for OAuth callback handling.
Args:
mode: Transport mode ("stdio", "streamable-http", etc.)
"""
self._transport_mode = mode
def get_transport_mode(self) -> str:
"""
Get the current transport mode.
Returns:
Current transport mode
"""
return self._transport_mode
def is_oauth21_enabled(self) -> bool:
"""
Check if OAuth 2.1 mode is enabled.
Returns:
True if OAuth 2.1 is enabled
"""
return self.oauth21_enabled
def is_external_oauth21_provider(self) -> bool:
"""
Check if external OAuth 2.1 provider mode is enabled.
When enabled, the server expects external OAuth flow with bearer tokens
in Authorization headers for tool calls. Protocol-level auth is disabled.
Returns:
True if external OAuth 2.1 provider is enabled
"""
return self.external_oauth21_provider
def detect_oauth_version(self, request_params: Dict[str, Any]) -> str:
"""
Detect OAuth version based on request parameters.
This method implements a conservative detection strategy:
- Only returns "oauth21" when we have clear indicators
- Defaults to "oauth20" for backward compatibility
- Respects the global oauth21_enabled flag
Args:
request_params: Request parameters from authorization or token request
Returns:
"oauth21" or "oauth20" based on detection
"""
# If OAuth 2.1 is not enabled globally, always return OAuth 2.0
if not self.oauth21_enabled:
return "oauth20"
# Use the structured type for cleaner detection logic
from auth.oauth_types import OAuthVersionDetectionParams
params = OAuthVersionDetectionParams.from_request(request_params)
# Clear OAuth 2.1 indicator: PKCE is present
if params.has_pkce:
return "oauth21"
# Additional detection: Check if we have an active OAuth 2.1 session
# This is important for tool calls where PKCE params aren't available
authenticated_user = request_params.get("authenticated_user")
if authenticated_user:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
if store.has_session(authenticated_user):
return "oauth21"
except (ImportError, AttributeError, RuntimeError):
pass # Fall back to OAuth 2.0 if session check fails
# For public clients in OAuth 2.1 mode, we require PKCE
# But since they didn't send PKCE, fall back to OAuth 2.0
# This ensures backward compatibility
# Default to OAuth 2.0 for maximum compatibility
return "oauth20"
def get_authorization_server_metadata(
self, scopes: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
Get OAuth authorization server metadata per RFC 8414.
Args:
scopes: Optional list of supported scopes to include in metadata
Returns:
Authorization server metadata dictionary
"""
oauth_base = self.get_oauth_base_url()
metadata = {
"issuer": "https://accounts.google.com",
"authorization_endpoint": f"{oauth_base}/oauth2/authorize",
"token_endpoint": f"{oauth_base}/oauth2/token",
"registration_endpoint": f"{oauth_base}/oauth2/register",
"jwks_uri": "https://www.googleapis.com/oauth2/v3/certs",
"userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
"response_types_supported": ["code", "token"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"token_endpoint_auth_methods_supported": [
"client_secret_post",
"client_secret_basic",
],
"code_challenge_methods_supported": self.supported_code_challenge_methods,
}
# Include scopes if provided
if scopes is not None:
metadata["scopes_supported"] = scopes
# Add OAuth 2.1 specific metadata
if self.oauth21_enabled:
metadata["pkce_required"] = True
# OAuth 2.1 deprecates implicit flow
metadata["response_types_supported"] = ["code"]
# OAuth 2.1 requires exact redirect URI matching
metadata["require_exact_redirect_uri"] = True
return metadata
# Global configuration instance with thread-safe access
_oauth_config = None
_oauth_config_lock = RLock()
def get_oauth_config() -> OAuthConfig:
"""
Get the global OAuth configuration instance.
Thread-safe singleton accessor.
Returns:
The singleton OAuth configuration instance
"""
global _oauth_config
with _oauth_config_lock:
if _oauth_config is None:
_oauth_config = OAuthConfig()
return _oauth_config
def reload_oauth_config() -> OAuthConfig:
"""
Reload the OAuth configuration from environment variables.
Thread-safe reload that prevents races with concurrent access.
Returns:
The reloaded OAuth configuration instance
"""
global _oauth_config
with _oauth_config_lock:
_oauth_config = OAuthConfig()
return _oauth_config
# Convenience functions for backward compatibility
def get_oauth_base_url() -> str:
"""Get OAuth base URL."""
return get_oauth_config().get_oauth_base_url()
def get_redirect_uris() -> List[str]:
"""Get all valid OAuth redirect URIs."""
return get_oauth_config().get_redirect_uris()
def get_allowed_origins() -> List[str]:
"""Get allowed CORS origins."""
return get_oauth_config().get_allowed_origins()
def is_oauth_configured() -> bool:
"""Check if OAuth is properly configured."""
return get_oauth_config().is_configured()
def set_transport_mode(mode: str) -> None:
"""Set the current transport mode."""
get_oauth_config().set_transport_mode(mode)
def get_transport_mode() -> str:
"""Get the current transport mode."""
return get_oauth_config().get_transport_mode()
def is_oauth21_enabled() -> bool:
"""Check if OAuth 2.1 is enabled."""
return get_oauth_config().is_oauth21_enabled()
def get_oauth_redirect_uri() -> str:
"""Get the primary OAuth redirect URI."""
return get_oauth_config().redirect_uri
def is_stateless_mode() -> bool:
"""Check if stateless mode is enabled."""
return get_oauth_config().stateless_mode
def is_external_oauth21_provider() -> bool:
"""Check if external OAuth 2.1 provider mode is enabled."""
return get_oauth_config().is_external_oauth21_provider()

229
auth/oauth_responses.py Normal file
View File

@@ -0,0 +1,229 @@
"""
Shared OAuth callback response templates.
Provides reusable HTML response templates for OAuth authentication flows
to eliminate duplication between server.py and oauth_callback_server.py.
"""
from fastapi.responses import HTMLResponse
from typing import Optional
def create_error_response(error_message: str, status_code: int = 400) -> HTMLResponse:
"""
Create a standardized error response for OAuth failures.
Args:
error_message: The error message to display
status_code: HTTP status code (default 400)
Returns:
HTMLResponse with error page
"""
content = f"""
<html>
<head><title>Authentication Error</title></head>
<body style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; max-width: 600px; margin: 40px auto; padding: 20px; text-align: center;">
<h2 style="color: #d32f2f;">Authentication Error</h2>
<p>{error_message}</p>
<p>Please ensure you grant the requested permissions. You can close this tab and try again.</p>
</body>
</html>
"""
return HTMLResponse(content=content, status_code=status_code)
def create_success_response(verified_user_id: Optional[str] = None) -> HTMLResponse:
"""
Create a standardized success response for OAuth authentication.
Args:
verified_user_id: The authenticated user's email (optional)
Returns:
HTMLResponse with success page
"""
# Handle the case where no user ID is provided
user_display = verified_user_id if verified_user_id else "Google User"
content = f"""<html>
<head>
<title>Authentication Successful</title>
<style>
* {{
margin: 0;
padding: 0;
box-sizing: border-box;
}}
body {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
background: linear-gradient(135deg,#0f172a,#1e293b,#334155);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
color: #1a1a1a;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}}
.container {{
background: rgba(255, 255, 255, 0.95);
backdrop-filter: blur(10px);
padding: 60px;
border-radius: 20px;
box-shadow: 0 30px 60px rgba(0, 0, 0, 0.12);
text-align: center;
max-width: 480px;
width: 90%;
transform: translateY(-20px);
animation: slideUp 0.6s ease-out;
}}
@keyframes slideUp {{
from {{
opacity: 0;
transform: translateY(0);
}}
to {{
opacity: 1;
transform: translateY(-20px);
}}
}}
.icon {{
width: 80px;
height: 80px;
margin: 0 auto 30px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
font-size: 40px;
color: white;
animation: pulse 2s ease-in-out infinite;
}}
@keyframes pulse {{
0%, 100% {{
transform: scale(1);
}}
50% {{
transform: scale(1.05);
}}
}}
h1 {{
font-size: 28px;
font-weight: 600;
margin-bottom: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}}
.message {{
font-size: 16px;
line-height: 1.6;
color: #4a5568;
margin-bottom: 20px;
}}
.user-id {{
font-weight: 600;
color: #667eea;
padding: 4px 12px;
background: rgba(102, 126, 234, 0.1);
border-radius: 6px;
display: inline-block;
margin: 0 4px;
}}
.button {{
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 16px 40px;
border: none;
border-radius: 30px;
font-size: 16px;
font-weight: 500;
cursor: pointer;
transition: all 0.3s ease;
margin-top: 30px;
display: inline-block;
text-decoration: none;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
}}
.button:hover {{
transform: translateY(-2px);
box-shadow: 0 7px 20px rgba(102, 126, 234, 0.4);
}}
.button:active {{
transform: translateY(0);
}}
.auto-close {{
font-size: 13px;
color: #a0aec0;
margin-top: 30px;
opacity: 0.8;
}}
</style>
<script>
function tryClose() {{
window.close();
// If window.close() was blocked by the browser, update the UI
setTimeout(function() {{
var btn = document.querySelector('.button');
if (btn) btn.textContent = 'You can close this tab manually';
var ac = document.querySelector('.auto-close');
if (ac) ac.style.display = 'none';
}}, 500);
}}
setTimeout(tryClose, 10000);
</script>
</head>
<body>
<div class="container">
<div class="icon">✓</div>
<h1>Authentication Successful</h1>
<div class="message">
You've been authenticated as <span class="user-id">{user_display}</span>
</div>
<div class="message">
Your credentials have been securely saved. You can now close this tab and retry your original command.
</div>
<button class="button" onclick="tryClose()">Close Tab</button>
<div class="auto-close">This tab will close automatically in 10 seconds</div>
</div>
</body>
</html>"""
return HTMLResponse(content=content)
def create_server_error_response(error_detail: str) -> HTMLResponse:
"""
Create a standardized server error response for OAuth processing failures.
Args:
error_detail: The detailed error message
Returns:
HTMLResponse with server error page
"""
content = f"""
<html>
<head><title>Authentication Processing Error</title></head>
<body style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; max-width: 600px; margin: 40px auto; padding: 20px; text-align: center;">
<h2 style="color: #d32f2f;">Authentication Processing Error</h2>
<p>An unexpected error occurred while processing your authentication: {error_detail}</p>
<p>Please try again. You can close this tab.</p>
</body>
</html>
"""
return HTMLResponse(content=content, status_code=500)

92
auth/oauth_types.py Normal file
View File

@@ -0,0 +1,92 @@
"""
Type definitions for OAuth authentication.
This module provides structured types for OAuth-related parameters,
improving code maintainability and type safety.
"""
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
from fastmcp.server.auth import AccessToken
class WorkspaceAccessToken(AccessToken):
"""AccessToken extended with workspace-specific fields."""
session_id: Optional[str] = None
sub: Optional[str] = None
email: Optional[str] = None
@dataclass
class OAuth21ServiceRequest:
"""
Encapsulates parameters for OAuth 2.1 service authentication requests.
This parameter object pattern reduces function complexity and makes
it easier to extend authentication parameters in the future.
"""
service_name: str
version: str
tool_name: str
user_google_email: str
required_scopes: List[str]
session_id: Optional[str] = None
auth_token_email: Optional[str] = None
allow_recent_auth: bool = False
context: Optional[Dict[str, Any]] = None
def to_legacy_params(self) -> dict:
"""Convert to legacy parameter format for backward compatibility."""
return {
"service_name": self.service_name,
"version": self.version,
"tool_name": self.tool_name,
"user_google_email": self.user_google_email,
"required_scopes": self.required_scopes,
}
@dataclass
class OAuthVersionDetectionParams:
"""
Parameters used for OAuth version detection.
Encapsulates the various signals we use to determine
whether a client supports OAuth 2.1 or needs OAuth 2.0.
"""
client_id: Optional[str] = None
client_secret: Optional[str] = None
code_challenge: Optional[str] = None
code_challenge_method: Optional[str] = None
code_verifier: Optional[str] = None
authenticated_user: Optional[str] = None
session_id: Optional[str] = None
@classmethod
def from_request(
cls, request_params: Dict[str, Any]
) -> "OAuthVersionDetectionParams":
"""Create from raw request parameters."""
return cls(
client_id=request_params.get("client_id"),
client_secret=request_params.get("client_secret"),
code_challenge=request_params.get("code_challenge"),
code_challenge_method=request_params.get("code_challenge_method"),
code_verifier=request_params.get("code_verifier"),
authenticated_user=request_params.get("authenticated_user"),
session_id=request_params.get("session_id"),
)
@property
def has_pkce(self) -> bool:
"""Check if PKCE parameters are present."""
return bool(self.code_challenge or self.code_verifier)
@property
def is_public_client(self) -> bool:
"""Check if this appears to be a public client (no secret)."""
return bool(self.client_id and not self.client_secret)

277
auth/permissions.py Normal file
View File

@@ -0,0 +1,277 @@
"""
Granular per-service permission levels.
Each service has named permission levels (cumulative), mapping to a list of
OAuth scopes. The levels for a service are ordered from least to most
permissive — requesting level N implicitly includes all scopes from levels < N.
Usage:
--permissions gmail:organize drive:readonly
Gmail levels: readonly, organize, drafts, send, full
Tasks levels: readonly, manage, full
Other services: readonly, full (extensible by adding entries to SERVICE_PERMISSION_LEVELS)
"""
import logging
from typing import Dict, FrozenSet, List, Optional, Tuple
from auth.scopes import (
GMAIL_READONLY_SCOPE,
GMAIL_LABELS_SCOPE,
GMAIL_MODIFY_SCOPE,
GMAIL_COMPOSE_SCOPE,
GMAIL_SEND_SCOPE,
GMAIL_SETTINGS_BASIC_SCOPE,
DRIVE_READONLY_SCOPE,
DRIVE_FILE_SCOPE,
DRIVE_SCOPE,
CALENDAR_READONLY_SCOPE,
CALENDAR_EVENTS_SCOPE,
CALENDAR_SCOPE,
DOCS_READONLY_SCOPE,
DOCS_WRITE_SCOPE,
SHEETS_READONLY_SCOPE,
SHEETS_WRITE_SCOPE,
CHAT_READONLY_SCOPE,
CHAT_WRITE_SCOPE,
CHAT_SPACES_SCOPE,
CHAT_SPACES_READONLY_SCOPE,
FORMS_BODY_SCOPE,
FORMS_BODY_READONLY_SCOPE,
FORMS_RESPONSES_READONLY_SCOPE,
SLIDES_SCOPE,
SLIDES_READONLY_SCOPE,
TASKS_SCOPE,
TASKS_READONLY_SCOPE,
CONTACTS_SCOPE,
CONTACTS_READONLY_SCOPE,
CUSTOM_SEARCH_SCOPE,
SCRIPT_PROJECTS_SCOPE,
SCRIPT_PROJECTS_READONLY_SCOPE,
SCRIPT_DEPLOYMENTS_SCOPE,
SCRIPT_DEPLOYMENTS_READONLY_SCOPE,
SCRIPT_PROCESSES_READONLY_SCOPE,
SCRIPT_METRICS_SCOPE,
)
logger = logging.getLogger(__name__)
# Ordered permission levels per service.
# Each entry is (level_name, [additional_scopes_at_this_level]).
# Scopes are CUMULATIVE: level N includes all scopes from levels 0..N.
SERVICE_PERMISSION_LEVELS: Dict[str, List[Tuple[str, List[str]]]] = {
"gmail": [
("readonly", [GMAIL_READONLY_SCOPE]),
("organize", [GMAIL_LABELS_SCOPE, GMAIL_MODIFY_SCOPE]),
("drafts", [GMAIL_COMPOSE_SCOPE]),
("send", [GMAIL_SEND_SCOPE]),
("full", [GMAIL_SETTINGS_BASIC_SCOPE]),
],
"drive": [
("readonly", [DRIVE_READONLY_SCOPE]),
("full", [DRIVE_SCOPE, DRIVE_FILE_SCOPE]),
],
"calendar": [
("readonly", [CALENDAR_READONLY_SCOPE]),
("full", [CALENDAR_SCOPE, CALENDAR_EVENTS_SCOPE]),
],
"docs": [
("readonly", [DOCS_READONLY_SCOPE, DRIVE_READONLY_SCOPE]),
("full", [DOCS_WRITE_SCOPE, DRIVE_READONLY_SCOPE, DRIVE_FILE_SCOPE]),
],
"sheets": [
("readonly", [SHEETS_READONLY_SCOPE, DRIVE_READONLY_SCOPE]),
("full", [SHEETS_WRITE_SCOPE, DRIVE_READONLY_SCOPE]),
],
"chat": [
("readonly", [CHAT_READONLY_SCOPE, CHAT_SPACES_READONLY_SCOPE]),
("full", [CHAT_WRITE_SCOPE, CHAT_SPACES_SCOPE]),
],
"forms": [
("readonly", [FORMS_BODY_READONLY_SCOPE, FORMS_RESPONSES_READONLY_SCOPE]),
("full", [FORMS_BODY_SCOPE, FORMS_RESPONSES_READONLY_SCOPE]),
],
"slides": [
("readonly", [SLIDES_READONLY_SCOPE]),
("full", [SLIDES_SCOPE]),
],
"tasks": [
("readonly", [TASKS_READONLY_SCOPE]),
("manage", [TASKS_SCOPE]),
("full", []),
],
"contacts": [
("readonly", [CONTACTS_READONLY_SCOPE]),
("full", [CONTACTS_SCOPE]),
],
"search": [
("readonly", [CUSTOM_SEARCH_SCOPE]),
("full", [CUSTOM_SEARCH_SCOPE]),
],
"appscript": [
(
"readonly",
[
SCRIPT_PROJECTS_READONLY_SCOPE,
SCRIPT_DEPLOYMENTS_READONLY_SCOPE,
SCRIPT_PROCESSES_READONLY_SCOPE,
SCRIPT_METRICS_SCOPE,
DRIVE_READONLY_SCOPE,
],
),
(
"full",
[
SCRIPT_PROJECTS_SCOPE,
SCRIPT_DEPLOYMENTS_SCOPE,
SCRIPT_PROCESSES_READONLY_SCOPE,
SCRIPT_METRICS_SCOPE,
DRIVE_FILE_SCOPE,
],
),
],
}
# Actions denied at specific permission levels.
# Maps service -> level -> frozenset of denied action names.
# Levels not listed here (or services without entries) deny nothing.
SERVICE_DENIED_ACTIONS: Dict[str, Dict[str, FrozenSet[str]]] = {
"tasks": {
"manage": frozenset({"delete", "clear_completed"}),
},
}
def is_action_denied(service: str, action: str) -> bool:
"""Check whether *action* is denied for *service* under current permissions.
Returns ``False`` when granular permissions mode is not active, when the
service has no permission entry, or when the configured level does not
deny the action.
"""
if _PERMISSIONS is None:
return False
level = _PERMISSIONS.get(service)
if level is None:
return False
denied = SERVICE_DENIED_ACTIONS.get(service, {}).get(level, frozenset())
return action in denied
# Module-level state: parsed --permissions config
# Dict mapping service_name -> level_name, e.g. {"gmail": "organize"}
_PERMISSIONS: Optional[Dict[str, str]] = None
def set_permissions(permissions: Optional[Dict[str, str]]) -> None:
"""Set granular permissions from parsed --permissions argument."""
global _PERMISSIONS
_PERMISSIONS = permissions
if permissions is not None:
logger.info("Granular permissions set: %s", permissions)
def get_permissions() -> Optional[Dict[str, str]]:
"""Return current permissions dict, or None if not using granular mode."""
return _PERMISSIONS
def is_permissions_mode() -> bool:
"""Check if granular permissions mode is active."""
return _PERMISSIONS is not None
def get_scopes_for_permission(service: str, level: str) -> List[str]:
"""
Get cumulative scopes for a service at a given permission level.
Returns all scopes up to and including the named level.
Raises ValueError if service or level is unknown.
"""
levels = SERVICE_PERMISSION_LEVELS.get(service)
if levels is None:
raise ValueError(f"Unknown service: '{service}'")
cumulative: List[str] = []
found = False
for level_name, level_scopes in levels:
cumulative.extend(level_scopes)
if level_name == level:
found = True
break
if not found:
valid = [name for name, _ in levels]
raise ValueError(
f"Unknown permission level '{level}' for service '{service}'. "
f"Valid levels: {valid}"
)
return sorted(set(cumulative))
def get_all_permission_scopes() -> List[str]:
"""
Get the combined scopes for all services at their configured permission levels.
Only meaningful when is_permissions_mode() is True.
"""
if _PERMISSIONS is None:
return []
all_scopes: set = set()
for service, level in _PERMISSIONS.items():
all_scopes.update(get_scopes_for_permission(service, level))
return list(all_scopes)
def get_allowed_scopes_set() -> Optional[set]:
"""
Get the set of allowed scopes under permissions mode (for tool filtering).
Returns None if permissions mode is not active.
"""
if _PERMISSIONS is None:
return None
return set(get_all_permission_scopes())
def get_valid_levels(service: str) -> List[str]:
"""Get valid permission level names for a service."""
levels = SERVICE_PERMISSION_LEVELS.get(service)
if levels is None:
return []
return [name for name, _ in levels]
def parse_permissions_arg(permissions_list: List[str]) -> Dict[str, str]:
"""
Parse --permissions arguments like ["gmail:organize", "drive:full"].
Returns dict mapping service -> level.
Raises ValueError on parse errors (unknown service, invalid level, bad format).
"""
result: Dict[str, str] = {}
for entry in permissions_list:
if ":" not in entry:
raise ValueError(
f"Invalid permission format: '{entry}'. "
f"Expected 'service:level' (e.g., 'gmail:organize', 'drive:readonly')"
)
service, level = entry.split(":", 1)
if service in result:
raise ValueError(f"Duplicate service in permissions: '{service}'")
if service not in SERVICE_PERMISSION_LEVELS:
raise ValueError(
f"Unknown service: '{service}'. "
f"Valid services: {sorted(SERVICE_PERMISSION_LEVELS.keys())}"
)
valid = get_valid_levels(service)
if level not in valid:
raise ValueError(
f"Unknown level '{level}' for service '{service}'. "
f"Valid levels: {valid}"
)
result[service] = level
return result

336
auth/scopes.py Normal file
View File

@@ -0,0 +1,336 @@
"""
Google Workspace OAuth Scopes
This module centralizes OAuth scope definitions for Google Workspace integration.
Separated from service_decorator.py to avoid circular imports.
"""
import logging
logger = logging.getLogger(__name__)
# Global variable to store enabled tools (set by main.py)
_ENABLED_TOOLS = None
# Individual OAuth Scope Constants
USERINFO_EMAIL_SCOPE = "https://www.googleapis.com/auth/userinfo.email"
USERINFO_PROFILE_SCOPE = "https://www.googleapis.com/auth/userinfo.profile"
OPENID_SCOPE = "openid"
CALENDAR_SCOPE = "https://www.googleapis.com/auth/calendar"
CALENDAR_READONLY_SCOPE = "https://www.googleapis.com/auth/calendar.readonly"
CALENDAR_EVENTS_SCOPE = "https://www.googleapis.com/auth/calendar.events"
# Google Drive scopes
DRIVE_SCOPE = "https://www.googleapis.com/auth/drive"
DRIVE_READONLY_SCOPE = "https://www.googleapis.com/auth/drive.readonly"
DRIVE_FILE_SCOPE = "https://www.googleapis.com/auth/drive.file"
# Google Docs scopes
DOCS_READONLY_SCOPE = "https://www.googleapis.com/auth/documents.readonly"
DOCS_WRITE_SCOPE = "https://www.googleapis.com/auth/documents"
# Gmail API scopes
GMAIL_READONLY_SCOPE = "https://www.googleapis.com/auth/gmail.readonly"
GMAIL_SEND_SCOPE = "https://www.googleapis.com/auth/gmail.send"
GMAIL_COMPOSE_SCOPE = "https://www.googleapis.com/auth/gmail.compose"
GMAIL_MODIFY_SCOPE = "https://www.googleapis.com/auth/gmail.modify"
GMAIL_LABELS_SCOPE = "https://www.googleapis.com/auth/gmail.labels"
GMAIL_SETTINGS_BASIC_SCOPE = "https://www.googleapis.com/auth/gmail.settings.basic"
# Google Chat API scopes
CHAT_READONLY_SCOPE = "https://www.googleapis.com/auth/chat.messages.readonly"
CHAT_WRITE_SCOPE = "https://www.googleapis.com/auth/chat.messages"
CHAT_SPACES_SCOPE = "https://www.googleapis.com/auth/chat.spaces"
CHAT_SPACES_READONLY_SCOPE = "https://www.googleapis.com/auth/chat.spaces.readonly"
# Google Sheets API scopes
SHEETS_READONLY_SCOPE = "https://www.googleapis.com/auth/spreadsheets.readonly"
SHEETS_WRITE_SCOPE = "https://www.googleapis.com/auth/spreadsheets"
# Google Forms API scopes
FORMS_BODY_SCOPE = "https://www.googleapis.com/auth/forms.body"
FORMS_BODY_READONLY_SCOPE = "https://www.googleapis.com/auth/forms.body.readonly"
FORMS_RESPONSES_READONLY_SCOPE = (
"https://www.googleapis.com/auth/forms.responses.readonly"
)
# Google Slides API scopes
SLIDES_SCOPE = "https://www.googleapis.com/auth/presentations"
SLIDES_READONLY_SCOPE = "https://www.googleapis.com/auth/presentations.readonly"
# Google Tasks API scopes
TASKS_SCOPE = "https://www.googleapis.com/auth/tasks"
TASKS_READONLY_SCOPE = "https://www.googleapis.com/auth/tasks.readonly"
# Google Contacts (People API) scopes
CONTACTS_SCOPE = "https://www.googleapis.com/auth/contacts"
CONTACTS_READONLY_SCOPE = "https://www.googleapis.com/auth/contacts.readonly"
# Google Custom Search API scope
CUSTOM_SEARCH_SCOPE = "https://www.googleapis.com/auth/cse"
# Google Apps Script API scopes
SCRIPT_PROJECTS_SCOPE = "https://www.googleapis.com/auth/script.projects"
SCRIPT_PROJECTS_READONLY_SCOPE = (
"https://www.googleapis.com/auth/script.projects.readonly"
)
SCRIPT_DEPLOYMENTS_SCOPE = "https://www.googleapis.com/auth/script.deployments"
SCRIPT_DEPLOYMENTS_READONLY_SCOPE = (
"https://www.googleapis.com/auth/script.deployments.readonly"
)
SCRIPT_PROCESSES_READONLY_SCOPE = "https://www.googleapis.com/auth/script.processes"
SCRIPT_METRICS_SCOPE = "https://www.googleapis.com/auth/script.metrics"
# Google scope hierarchy: broader scopes that implicitly cover narrower ones.
# See https://developers.google.com/gmail/api/auth/scopes,
# https://developers.google.com/drive/api/guides/api-specific-auth, etc.
SCOPE_HIERARCHY = {
GMAIL_MODIFY_SCOPE: {
GMAIL_READONLY_SCOPE,
GMAIL_SEND_SCOPE,
GMAIL_COMPOSE_SCOPE,
GMAIL_LABELS_SCOPE,
},
DRIVE_SCOPE: {DRIVE_READONLY_SCOPE, DRIVE_FILE_SCOPE},
CALENDAR_SCOPE: {CALENDAR_READONLY_SCOPE, CALENDAR_EVENTS_SCOPE},
DOCS_WRITE_SCOPE: {DOCS_READONLY_SCOPE},
SHEETS_WRITE_SCOPE: {SHEETS_READONLY_SCOPE},
SLIDES_SCOPE: {SLIDES_READONLY_SCOPE},
TASKS_SCOPE: {TASKS_READONLY_SCOPE},
CONTACTS_SCOPE: {CONTACTS_READONLY_SCOPE},
CHAT_WRITE_SCOPE: {CHAT_READONLY_SCOPE},
CHAT_SPACES_SCOPE: {CHAT_SPACES_READONLY_SCOPE},
FORMS_BODY_SCOPE: {FORMS_BODY_READONLY_SCOPE},
SCRIPT_PROJECTS_SCOPE: {SCRIPT_PROJECTS_READONLY_SCOPE},
SCRIPT_DEPLOYMENTS_SCOPE: {SCRIPT_DEPLOYMENTS_READONLY_SCOPE},
}
def has_required_scopes(available_scopes, required_scopes):
"""
Check if available scopes satisfy all required scopes, accounting for
Google's scope hierarchy (e.g., gmail.modify covers gmail.readonly).
Args:
available_scopes: Scopes the credentials have (set, list, or frozenset).
required_scopes: Scopes that are required (set, list, or frozenset).
Returns:
True if all required scopes are satisfied.
"""
available = set(available_scopes or [])
required = set(required_scopes or [])
# Expand available scopes with implied narrower scopes
expanded = set(available)
for broad_scope, covered in SCOPE_HIERARCHY.items():
if broad_scope in available:
expanded.update(covered)
return all(scope in expanded for scope in required)
# Base OAuth scopes required for user identification
BASE_SCOPES = [USERINFO_EMAIL_SCOPE, USERINFO_PROFILE_SCOPE, OPENID_SCOPE]
# Service-specific scope groups
DOCS_SCOPES = [
DOCS_READONLY_SCOPE,
DOCS_WRITE_SCOPE,
DRIVE_READONLY_SCOPE,
DRIVE_FILE_SCOPE,
]
CALENDAR_SCOPES = [CALENDAR_SCOPE, CALENDAR_READONLY_SCOPE, CALENDAR_EVENTS_SCOPE]
DRIVE_SCOPES = [DRIVE_SCOPE, DRIVE_READONLY_SCOPE, DRIVE_FILE_SCOPE]
GMAIL_SCOPES = [
GMAIL_READONLY_SCOPE,
GMAIL_SEND_SCOPE,
GMAIL_COMPOSE_SCOPE,
GMAIL_MODIFY_SCOPE,
GMAIL_LABELS_SCOPE,
GMAIL_SETTINGS_BASIC_SCOPE,
]
CHAT_SCOPES = [
CHAT_READONLY_SCOPE,
CHAT_WRITE_SCOPE,
CHAT_SPACES_SCOPE,
CHAT_SPACES_READONLY_SCOPE,
]
SHEETS_SCOPES = [SHEETS_READONLY_SCOPE, SHEETS_WRITE_SCOPE, DRIVE_READONLY_SCOPE]
FORMS_SCOPES = [
FORMS_BODY_SCOPE,
FORMS_BODY_READONLY_SCOPE,
FORMS_RESPONSES_READONLY_SCOPE,
]
SLIDES_SCOPES = [SLIDES_SCOPE, SLIDES_READONLY_SCOPE]
TASKS_SCOPES = [TASKS_SCOPE, TASKS_READONLY_SCOPE]
CONTACTS_SCOPES = [CONTACTS_SCOPE, CONTACTS_READONLY_SCOPE]
CUSTOM_SEARCH_SCOPES = [CUSTOM_SEARCH_SCOPE]
SCRIPT_SCOPES = [
SCRIPT_PROJECTS_SCOPE,
SCRIPT_PROJECTS_READONLY_SCOPE,
SCRIPT_DEPLOYMENTS_SCOPE,
SCRIPT_DEPLOYMENTS_READONLY_SCOPE,
SCRIPT_PROCESSES_READONLY_SCOPE, # Required for list_script_processes
SCRIPT_METRICS_SCOPE, # Required for get_script_metrics
DRIVE_FILE_SCOPE, # Required for list/delete script projects (uses Drive API)
]
# Tool-to-scopes mapping
TOOL_SCOPES_MAP = {
"gmail": GMAIL_SCOPES,
"drive": DRIVE_SCOPES,
"calendar": CALENDAR_SCOPES,
"docs": DOCS_SCOPES,
"sheets": SHEETS_SCOPES,
"chat": CHAT_SCOPES,
"forms": FORMS_SCOPES,
"slides": SLIDES_SCOPES,
"tasks": TASKS_SCOPES,
"contacts": CONTACTS_SCOPES,
"search": CUSTOM_SEARCH_SCOPES,
"appscript": SCRIPT_SCOPES,
}
# Tool-to-read-only-scopes mapping
TOOL_READONLY_SCOPES_MAP = {
"gmail": [GMAIL_READONLY_SCOPE],
"drive": [DRIVE_READONLY_SCOPE],
"calendar": [CALENDAR_READONLY_SCOPE],
"docs": [DOCS_READONLY_SCOPE, DRIVE_READONLY_SCOPE],
"sheets": [SHEETS_READONLY_SCOPE, DRIVE_READONLY_SCOPE],
"chat": [CHAT_READONLY_SCOPE, CHAT_SPACES_READONLY_SCOPE],
"forms": [FORMS_BODY_READONLY_SCOPE, FORMS_RESPONSES_READONLY_SCOPE],
"slides": [SLIDES_READONLY_SCOPE],
"tasks": [TASKS_READONLY_SCOPE],
"contacts": [CONTACTS_READONLY_SCOPE],
"search": CUSTOM_SEARCH_SCOPES,
"appscript": [
SCRIPT_PROJECTS_READONLY_SCOPE,
SCRIPT_DEPLOYMENTS_READONLY_SCOPE,
SCRIPT_PROCESSES_READONLY_SCOPE,
SCRIPT_METRICS_SCOPE,
DRIVE_READONLY_SCOPE,
],
}
def set_enabled_tools(enabled_tools):
"""
Set the globally enabled tools list.
Args:
enabled_tools: List of enabled tool names.
"""
global _ENABLED_TOOLS
_ENABLED_TOOLS = enabled_tools
logger.info(f"Enabled tools set for scope management: {enabled_tools}")
# Global variable to store read-only mode (set by main.py)
_READ_ONLY_MODE = False
def set_read_only(enabled: bool):
"""
Set the global read-only mode.
Args:
enabled: Boolean indicating if read-only mode should be enabled.
"""
global _READ_ONLY_MODE
_READ_ONLY_MODE = enabled
logger.info(f"Read-only mode set to: {enabled}")
def is_read_only_mode() -> bool:
"""Check if read-only mode is enabled."""
return _READ_ONLY_MODE
def get_all_read_only_scopes() -> list[str]:
"""Get all possible read-only scopes across all tools."""
all_scopes = set(BASE_SCOPES)
for scopes in TOOL_READONLY_SCOPES_MAP.values():
all_scopes.update(scopes)
return list(all_scopes)
def get_current_scopes():
"""
Returns scopes for currently enabled tools.
Uses globally set enabled tools or all tools if not set.
.. deprecated::
This function is a thin wrapper around get_scopes_for_tools() and exists
for backwards compatibility. Prefer using get_scopes_for_tools() directly
for new code, which allows explicit control over the tool list parameter.
Returns:
List of unique scopes for the enabled tools plus base scopes.
"""
return get_scopes_for_tools(_ENABLED_TOOLS)
def get_scopes_for_tools(enabled_tools=None):
"""
Returns scopes for enabled tools only.
Args:
enabled_tools: List of enabled tool names. If None, returns all scopes.
Returns:
List of unique scopes for the enabled tools plus base scopes.
"""
# Granular permissions mode overrides both full and read-only scope maps.
# Lazy import with guard to avoid circular dependency during module init
# (SCOPES = get_scopes_for_tools() runs at import time before auth.permissions
# is fully loaded, but permissions mode is never active at that point).
try:
from auth.permissions import is_permissions_mode, get_all_permission_scopes
if is_permissions_mode():
scopes = BASE_SCOPES.copy()
scopes.extend(get_all_permission_scopes())
logger.debug(
"Generated scopes from granular permissions: %d unique scopes",
len(set(scopes)),
)
return list(set(scopes))
except ImportError:
pass
if enabled_tools is None:
# Default behavior - return all scopes
enabled_tools = TOOL_SCOPES_MAP.keys()
# Start with base scopes (always required)
scopes = BASE_SCOPES.copy()
# Determine which map to use based on read-only mode
scope_map = TOOL_READONLY_SCOPES_MAP if _READ_ONLY_MODE else TOOL_SCOPES_MAP
mode_str = "read-only" if _READ_ONLY_MODE else "full"
# Add scopes for each enabled tool
for tool in enabled_tools:
if tool in scope_map:
scopes.extend(scope_map[tool])
logger.debug(
f"Generated {mode_str} scopes for tools {list(enabled_tools)}: {len(set(scopes))} unique scopes"
)
# Return unique scopes
return list(set(scopes))
# Combined scopes for all supported Google Workspace operations (backwards compatibility)
SCOPES = get_scopes_for_tools()

862
auth/service_decorator.py Normal file
View File

@@ -0,0 +1,862 @@
import inspect
import logging
import re
from functools import wraps
from typing import Dict, List, Optional, Any, Callable, Union, Tuple
from contextlib import ExitStack
from google.auth.exceptions import RefreshError
from googleapiclient.discovery import build
from fastmcp.server.dependencies import get_access_token, get_context
from auth.google_auth import get_authenticated_google_service, GoogleAuthenticationError
from auth.oauth21_session_store import (
get_auth_provider,
get_oauth21_session_store,
ensure_session_from_access_token,
)
from auth.oauth_config import (
is_oauth21_enabled,
get_oauth_config,
is_external_oauth21_provider,
)
from core.context import set_fastmcp_session_id
from auth.scopes import (
GMAIL_READONLY_SCOPE,
GMAIL_SEND_SCOPE,
GMAIL_COMPOSE_SCOPE,
GMAIL_MODIFY_SCOPE,
GMAIL_LABELS_SCOPE,
GMAIL_SETTINGS_BASIC_SCOPE,
DRIVE_SCOPE,
DRIVE_READONLY_SCOPE,
DRIVE_FILE_SCOPE,
DOCS_READONLY_SCOPE,
DOCS_WRITE_SCOPE,
CALENDAR_READONLY_SCOPE,
CALENDAR_EVENTS_SCOPE,
SHEETS_READONLY_SCOPE,
SHEETS_WRITE_SCOPE,
CHAT_READONLY_SCOPE,
CHAT_WRITE_SCOPE,
CHAT_SPACES_SCOPE,
CHAT_SPACES_READONLY_SCOPE,
FORMS_BODY_SCOPE,
FORMS_BODY_READONLY_SCOPE,
FORMS_RESPONSES_READONLY_SCOPE,
SLIDES_SCOPE,
SLIDES_READONLY_SCOPE,
TASKS_SCOPE,
TASKS_READONLY_SCOPE,
CONTACTS_SCOPE,
CONTACTS_READONLY_SCOPE,
CUSTOM_SEARCH_SCOPE,
SCRIPT_PROJECTS_SCOPE,
SCRIPT_PROJECTS_READONLY_SCOPE,
SCRIPT_DEPLOYMENTS_SCOPE,
SCRIPT_DEPLOYMENTS_READONLY_SCOPE,
has_required_scopes,
)
logger = logging.getLogger(__name__)
# Authentication helper functions
async def _get_auth_context(
tool_name: str,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
Get authentication context from FastMCP.
Returns:
Tuple of (authenticated_user, auth_method, mcp_session_id)
"""
try:
ctx = get_context()
if not ctx:
return None, None, None
authenticated_user = await ctx.get_state("authenticated_user_email")
auth_method = await ctx.get_state("authenticated_via")
mcp_session_id = ctx.session_id if hasattr(ctx, "session_id") else None
if mcp_session_id:
set_fastmcp_session_id(mcp_session_id)
logger.info(
f"[{tool_name}] Auth from middleware: authenticated_user={authenticated_user}, auth_method={auth_method}, session_id={mcp_session_id}"
)
return authenticated_user, auth_method, mcp_session_id
except Exception as e:
logger.debug(f"[{tool_name}] Could not get FastMCP context: {e}")
return None, None, None
def _detect_oauth_version(
authenticated_user: Optional[str], mcp_session_id: Optional[str], tool_name: str
) -> bool:
"""
Detect whether to use OAuth 2.1 based on configuration and context.
Returns:
True if OAuth 2.1 should be used, False otherwise
"""
if not is_oauth21_enabled():
return False
# When OAuth 2.1 is enabled globally, ALWAYS use OAuth 2.1 for authenticated users
if authenticated_user:
logger.info(
f"[{tool_name}] OAuth 2.1 mode: Using OAuth 2.1 for authenticated user '{authenticated_user}'"
)
return True
# If FastMCP protocol-level auth is enabled, a validated access token should
# be available even if middleware state wasn't populated.
try:
if get_access_token() is not None:
logger.info(
f"[{tool_name}] OAuth 2.1 mode: Using OAuth 2.1 based on validated access token"
)
return True
except Exception as e:
logger.debug(
f"[{tool_name}] Could not inspect access token for OAuth mode: {e}"
)
# Only use version detection for unauthenticated requests
config = get_oauth_config()
request_params = {}
if mcp_session_id:
request_params["session_id"] = mcp_session_id
oauth_version = config.detect_oauth_version(request_params)
use_oauth21 = oauth_version == "oauth21"
logger.info(
f"[{tool_name}] OAuth version detected: {oauth_version}, will use OAuth 2.1: {use_oauth21}"
)
return use_oauth21
def _update_email_in_args(args: tuple, index: int, new_email: str) -> tuple:
"""Update email at specific index in args tuple."""
if index < len(args):
args_list = list(args)
args_list[index] = new_email
return tuple(args_list)
return args
def _override_oauth21_user_email(
use_oauth21: bool,
authenticated_user: Optional[str],
current_user_email: str,
args: tuple,
kwargs: dict,
param_names: List[str],
tool_name: str,
service_type: str = "",
) -> Tuple[str, tuple]:
"""
Override user_google_email with authenticated user when using OAuth 2.1.
Returns:
Tuple of (updated_user_email, updated_args)
"""
if not (
use_oauth21 and authenticated_user and current_user_email != authenticated_user
):
return current_user_email, args
service_suffix = f" for service '{service_type}'" if service_type else ""
logger.info(
f"[{tool_name}] OAuth 2.1: Overriding user_google_email from '{current_user_email}' to authenticated user '{authenticated_user}'{service_suffix}"
)
# Update in kwargs if present
if "user_google_email" in kwargs:
kwargs["user_google_email"] = authenticated_user
# Update in args if user_google_email is passed positionally
try:
user_email_index = param_names.index("user_google_email")
args = _update_email_in_args(args, user_email_index, authenticated_user)
except ValueError:
pass # user_google_email not in positional parameters
return authenticated_user, args
async def _authenticate_service(
use_oauth21: bool,
service_name: str,
service_version: str,
tool_name: str,
user_google_email: str,
resolved_scopes: List[str],
mcp_session_id: Optional[str],
authenticated_user: Optional[str],
) -> Tuple[Any, str]:
"""
Authenticate and get Google service using appropriate OAuth version.
Returns:
Tuple of (service, actual_user_email)
"""
if use_oauth21:
logger.debug(f"[{tool_name}] Using OAuth 2.1 flow")
return await get_authenticated_google_service_oauth21(
service_name=service_name,
version=service_version,
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=resolved_scopes,
session_id=mcp_session_id,
auth_token_email=authenticated_user,
allow_recent_auth=False,
)
else:
logger.debug(f"[{tool_name}] Using legacy OAuth 2.0 flow")
return await get_authenticated_google_service(
service_name=service_name,
version=service_version,
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=resolved_scopes,
session_id=mcp_session_id,
)
async def get_authenticated_google_service_oauth21(
service_name: str,
version: str,
tool_name: str,
user_google_email: str,
required_scopes: List[str],
session_id: Optional[str] = None,
auth_token_email: Optional[str] = None,
allow_recent_auth: bool = False,
) -> tuple[Any, str]:
"""
OAuth 2.1 authentication using the session store with security validation.
"""
provider = get_auth_provider()
access_token = get_access_token()
if provider and access_token:
token_email = None
if getattr(access_token, "claims", None):
token_email = access_token.claims.get("email")
resolved_email = token_email or auth_token_email or user_google_email
if not resolved_email:
raise GoogleAuthenticationError(
"Authenticated user email could not be determined from access token."
)
if auth_token_email and token_email and token_email != auth_token_email:
raise GoogleAuthenticationError(
"Access token email does not match authenticated session context."
)
if token_email and user_google_email and token_email != user_google_email:
raise GoogleAuthenticationError(
f"Authenticated account {token_email} does not match requested user {user_google_email}."
)
credentials = ensure_session_from_access_token(
access_token, resolved_email, session_id
)
if not credentials:
raise GoogleAuthenticationError(
"Unable to build Google credentials from authenticated access token."
)
scopes_available = set(credentials.scopes or [])
if not scopes_available and getattr(access_token, "scopes", None):
scopes_available = set(access_token.scopes)
if not has_required_scopes(scopes_available, required_scopes):
raise GoogleAuthenticationError(
f"OAuth credentials lack required scopes. Need: {required_scopes}, Have: {sorted(scopes_available)}"
)
service = build(service_name, version, credentials=credentials)
logger.info(f"[{tool_name}] Authenticated {service_name} for {resolved_email}")
return service, resolved_email
store = get_oauth21_session_store()
# Use the validation method to ensure session can only access its own credentials
credentials = store.get_credentials_with_validation(
requested_user_email=user_google_email,
session_id=session_id,
auth_token_email=auth_token_email,
allow_recent_auth=allow_recent_auth,
)
if not credentials:
raise GoogleAuthenticationError(
f"Access denied: Cannot retrieve credentials for {user_google_email}. "
f"You can only access credentials for your authenticated account."
)
if not credentials.scopes:
scopes_available = set(required_scopes)
else:
scopes_available = set(credentials.scopes)
if not has_required_scopes(scopes_available, required_scopes):
raise GoogleAuthenticationError(
f"OAuth 2.1 credentials lack required scopes. Need: {required_scopes}, Have: {sorted(scopes_available)}"
)
service = build(service_name, version, credentials=credentials)
logger.info(f"[{tool_name}] Authenticated {service_name} for {user_google_email}")
return service, user_google_email
def _extract_oauth21_user_email(
authenticated_user: Optional[str], func_name: str
) -> str:
"""
Extract user email for OAuth 2.1 mode.
Args:
authenticated_user: The authenticated user from context
func_name: Name of the function being decorated (for error messages)
Returns:
User email string
Raises:
Exception: If no authenticated user found in OAuth 2.1 mode
"""
if not authenticated_user:
raise Exception(
f"OAuth 2.1 mode requires an authenticated user for {func_name}, but none was found."
)
return authenticated_user
def _extract_oauth20_user_email(
args: tuple, kwargs: dict, wrapper_sig: inspect.Signature
) -> str:
"""
Extract user email for OAuth 2.0 mode from function arguments.
Args:
args: Positional arguments passed to wrapper
kwargs: Keyword arguments passed to wrapper
wrapper_sig: Function signature for parameter binding
Returns:
User email string
Raises:
Exception: If user_google_email parameter not found
"""
bound_args = wrapper_sig.bind(*args, **kwargs)
bound_args.apply_defaults()
user_google_email = bound_args.arguments.get("user_google_email")
if not user_google_email:
raise Exception("'user_google_email' parameter is required but was not found.")
return user_google_email
def _remove_user_email_arg_from_docstring(docstring: str) -> str:
"""
Remove user_google_email parameter documentation from docstring.
Args:
docstring: The original function docstring
Returns:
Modified docstring with user_google_email parameter removed
"""
if not docstring:
return docstring
# Pattern to match user_google_email parameter documentation
# Handles various formats like:
# - user_google_email (str): The user's Google email address. Required.
# - user_google_email: Description
# - user_google_email (str) - Description
patterns = [
r"^\s*user_google_email\s*\([^)]*\)\s*:\s*[^\n]*\.?\s*(?:Required\.?)?\s*\n",
r"^\s*user_google_email\s*:\s*[^\n]*\n",
r"^\s*user_google_email\s*\([^)]*\)\s*-\s*[^\n]*\n",
]
modified_docstring = docstring
for pattern in patterns:
modified_docstring = re.sub(pattern, "", modified_docstring, flags=re.MULTILINE)
# Clean up any sequence of 3 or more newlines that might have been created
modified_docstring = re.sub(r"\n{3,}", "\n\n", modified_docstring)
return modified_docstring
# Service configuration mapping
SERVICE_CONFIGS = {
"gmail": {"service": "gmail", "version": "v1"},
"drive": {"service": "drive", "version": "v3"},
"calendar": {"service": "calendar", "version": "v3"},
"docs": {"service": "docs", "version": "v1"},
"sheets": {"service": "sheets", "version": "v4"},
"chat": {"service": "chat", "version": "v1"},
"forms": {"service": "forms", "version": "v1"},
"slides": {"service": "slides", "version": "v1"},
"tasks": {"service": "tasks", "version": "v1"},
"people": {"service": "people", "version": "v1"},
"customsearch": {"service": "customsearch", "version": "v1"},
"script": {"service": "script", "version": "v1"},
}
# Scope group definitions for easy reference
SCOPE_GROUPS = {
# Gmail scopes
"gmail_read": GMAIL_READONLY_SCOPE,
"gmail_send": GMAIL_SEND_SCOPE,
"gmail_compose": GMAIL_COMPOSE_SCOPE,
"gmail_modify": GMAIL_MODIFY_SCOPE,
"gmail_labels": GMAIL_LABELS_SCOPE,
"gmail_settings_basic": GMAIL_SETTINGS_BASIC_SCOPE,
# Drive scopes
"drive": DRIVE_SCOPE,
"drive_read": DRIVE_READONLY_SCOPE,
"drive_file": DRIVE_FILE_SCOPE,
# Docs scopes
"docs_read": DOCS_READONLY_SCOPE,
"docs_write": DOCS_WRITE_SCOPE,
# Calendar scopes
"calendar_read": CALENDAR_READONLY_SCOPE,
"calendar_events": CALENDAR_EVENTS_SCOPE,
# Sheets scopes
"sheets_read": SHEETS_READONLY_SCOPE,
"sheets_write": SHEETS_WRITE_SCOPE,
# Chat scopes
"chat_read": CHAT_READONLY_SCOPE,
"chat_write": CHAT_WRITE_SCOPE,
"chat_spaces": CHAT_SPACES_SCOPE,
"chat_spaces_readonly": CHAT_SPACES_READONLY_SCOPE,
# Forms scopes
"forms": FORMS_BODY_SCOPE,
"forms_read": FORMS_BODY_READONLY_SCOPE,
"forms_responses_read": FORMS_RESPONSES_READONLY_SCOPE,
# Slides scopes
"slides": SLIDES_SCOPE,
"slides_read": SLIDES_READONLY_SCOPE,
# Tasks scopes
"tasks": TASKS_SCOPE,
"tasks_read": TASKS_READONLY_SCOPE,
# Contacts scopes
"contacts": CONTACTS_SCOPE,
"contacts_read": CONTACTS_READONLY_SCOPE,
# Custom Search scope
"customsearch": CUSTOM_SEARCH_SCOPE,
# Apps Script scopes
"script_readonly": SCRIPT_PROJECTS_READONLY_SCOPE,
"script_projects": SCRIPT_PROJECTS_SCOPE,
"script_deployments": SCRIPT_DEPLOYMENTS_SCOPE,
"script_deployments_readonly": SCRIPT_DEPLOYMENTS_READONLY_SCOPE,
}
def _resolve_scopes(scopes: Union[str, List[str]]) -> List[str]:
"""Resolve scope names to actual scope URLs."""
if isinstance(scopes, str):
if scopes in SCOPE_GROUPS:
return [SCOPE_GROUPS[scopes]]
else:
return [scopes]
resolved = []
for scope in scopes:
if scope in SCOPE_GROUPS:
resolved.append(SCOPE_GROUPS[scope])
else:
resolved.append(scope)
return resolved
def _handle_token_refresh_error(
error: RefreshError, user_email: str, service_name: str
) -> str:
"""
Handle token refresh errors gracefully, particularly expired/revoked tokens.
Args:
error: The RefreshError that occurred
user_email: User's email address
service_name: Name of the Google service
Returns:
A user-friendly error message with instructions for reauthentication
"""
error_str = str(error)
if (
"invalid_grant" in error_str.lower()
or "expired or revoked" in error_str.lower()
):
logger.warning(
f"Token expired or revoked for user {user_email} accessing {service_name}"
)
service_display_name = f"Google {service_name.title()}"
if is_oauth21_enabled():
if is_external_oauth21_provider():
oauth21_step = (
"Provide a valid OAuth 2.1 bearer token in the Authorization header"
)
else:
oauth21_step = "Sign in through your MCP client's OAuth 2.1 flow"
return (
f"**Authentication Required: Token Expired/Revoked for {service_display_name}**\n\n"
f"Your Google authentication token for {user_email} has expired or been revoked. "
f"This commonly happens when:\n"
f"- The token has been unused for an extended period\n"
f"- You've changed your Google account password\n"
f"- You've revoked access to the application\n\n"
f"**To resolve this, please:**\n"
f"1. {oauth21_step}\n"
f"2. Retry your original command\n\n"
f"The application will automatically use the new credentials once authentication is complete."
)
return (
f"**Authentication Required: Token Expired/Revoked for {service_display_name}**\n\n"
f"Your Google authentication token for {user_email} has expired or been revoked. "
f"This commonly happens when:\n"
f"- The token has been unused for an extended period\n"
f"- You've changed your Google account password\n"
f"- You've revoked access to the application\n\n"
f"**To resolve this, please:**\n"
f"1. Run `start_google_auth` with your email ({user_email}) and service_name='{service_display_name}'\n"
f"2. Complete the authentication flow in your browser\n"
f"3. Retry your original command\n\n"
f"The application will automatically use the new credentials once authentication is complete."
)
else:
# Handle other types of refresh errors
logger.error(f"Unexpected refresh error for user {user_email}: {error}")
if is_oauth21_enabled():
if is_external_oauth21_provider():
return (
f"Authentication error occurred for {user_email}. "
"Please provide a valid OAuth 2.1 bearer token and retry."
)
return (
f"Authentication error occurred for {user_email}. "
"Please sign in via your MCP client's OAuth 2.1 flow and retry."
)
return (
f"Authentication error occurred for {user_email}. "
f"Please try running `start_google_auth` with your email and the appropriate service name to reauthenticate."
)
def require_google_service(
service_type: str,
scopes: Union[str, List[str]],
version: Optional[str] = None,
):
"""
Decorator that automatically handles Google service authentication and injection.
Args:
service_type: Type of Google service ("gmail", "drive", "calendar", etc.)
scopes: Required scopes (can be scope group names or actual URLs)
version: Service version (defaults to standard version for service type)
Usage:
@require_google_service("gmail", "gmail_read")
async def search_messages(service, user_google_email: str, query: str):
# service parameter is automatically injected
# Original authentication logic is handled automatically
"""
def decorator(func: Callable) -> Callable:
original_sig = inspect.signature(func)
params = list(original_sig.parameters.values())
# The decorated function must have 'service' as its first parameter.
if not params or params[0].name != "service":
raise TypeError(
f"Function '{func.__name__}' decorated with @require_google_service "
"must have 'service' as its first parameter."
)
# Create a new signature for the wrapper that excludes the 'service' parameter.
# In OAuth 2.1 mode, also exclude 'user_google_email' since it's automatically determined.
if is_oauth21_enabled():
# Remove both 'service' and 'user_google_email' parameters
filtered_params = [p for p in params[1:] if p.name != "user_google_email"]
wrapper_sig = original_sig.replace(parameters=filtered_params)
else:
# Only remove 'service' parameter for OAuth 2.0 mode
wrapper_sig = original_sig.replace(parameters=params[1:])
@wraps(func)
async def wrapper(*args, **kwargs):
# Note: `args` and `kwargs` are now the arguments for the *wrapper*,
# which does not include 'service'.
# Get authentication context early to determine OAuth mode
authenticated_user, auth_method, mcp_session_id = await _get_auth_context(
func.__name__
)
# Extract user_google_email based on OAuth mode
if is_oauth21_enabled():
user_google_email = _extract_oauth21_user_email(
authenticated_user, func.__name__
)
else:
user_google_email = _extract_oauth20_user_email(
args, kwargs, wrapper_sig
)
# Get service configuration from the decorator's arguments
if service_type not in SERVICE_CONFIGS:
raise Exception(f"Unknown service type: {service_type}")
config = SERVICE_CONFIGS[service_type]
service_name = config["service"]
service_version = version or config["version"]
# Resolve scopes
resolved_scopes = _resolve_scopes(scopes)
try:
tool_name = func.__name__
# Log authentication status
logger.debug(
f"[{tool_name}] Auth: {authenticated_user or 'none'} via {auth_method or 'none'} (session: {mcp_session_id[:8] if mcp_session_id else 'none'})"
)
# Detect OAuth version
use_oauth21 = _detect_oauth_version(
authenticated_user, mcp_session_id, tool_name
)
# In OAuth 2.1 mode, user_google_email is already set to authenticated_user
# In OAuth 2.0 mode, we may need to override it
if not is_oauth21_enabled():
wrapper_params = list(wrapper_sig.parameters.keys())
user_google_email, args = _override_oauth21_user_email(
use_oauth21,
authenticated_user,
user_google_email,
args,
kwargs,
wrapper_params,
tool_name,
)
# Authenticate service
service, actual_user_email = await _authenticate_service(
use_oauth21,
service_name,
service_version,
tool_name,
user_google_email,
resolved_scopes,
mcp_session_id,
authenticated_user,
)
except GoogleAuthenticationError as e:
logger.error(
f"[{tool_name}] GoogleAuthenticationError during authentication. "
f"Method={auth_method or 'none'}, User={authenticated_user or 'none'}, "
f"Service={service_name} v{service_version}, MCPSessionID={mcp_session_id or 'none'}: {e}"
)
# Re-raise the original error without wrapping it
raise
try:
# In OAuth 2.1 mode, we need to add user_google_email to kwargs since it was removed from signature
if is_oauth21_enabled():
kwargs["user_google_email"] = user_google_email
# Prepend the fetched service object to the original arguments
return await func(service, *args, **kwargs)
except RefreshError as e:
error_message = _handle_token_refresh_error(
e, actual_user_email, service_name
)
raise GoogleAuthenticationError(error_message)
finally:
if service:
service.close()
# Set the wrapper's signature to the one without 'service'
wrapper.__signature__ = wrapper_sig
# Conditionally modify docstring to remove user_google_email parameter documentation
if is_oauth21_enabled():
logger.debug(
"OAuth 2.1 mode enabled, removing user_google_email from docstring"
)
if func.__doc__:
wrapper.__doc__ = _remove_user_email_arg_from_docstring(func.__doc__)
# Attach required scopes to the wrapper for tool filtering
wrapper._required_google_scopes = _resolve_scopes(scopes)
return wrapper
return decorator
def require_multiple_services(service_configs: List[Dict[str, Any]]):
"""
Decorator for functions that need multiple Google services.
Args:
service_configs: List of service configurations, each containing:
- service_type: Type of service
- scopes: Required scopes
- param_name: Name to inject service as (e.g., 'drive_service', 'docs_service')
- version: Optional version override
Usage:
@require_multiple_services([
{"service_type": "drive", "scopes": "drive_read", "param_name": "drive_service"},
{"service_type": "docs", "scopes": "docs_read", "param_name": "docs_service"}
])
async def get_doc_with_metadata(drive_service, docs_service, user_google_email: str, doc_id: str):
# Both services are automatically injected
"""
def decorator(func: Callable) -> Callable:
original_sig = inspect.signature(func)
service_param_names = {config["param_name"] for config in service_configs}
params = list(original_sig.parameters.values())
# Remove injected service params from the wrapper signature; drop user_google_email only for OAuth 2.1.
filtered_params = [p for p in params if p.name not in service_param_names]
if is_oauth21_enabled():
filtered_params = [
p for p in filtered_params if p.name != "user_google_email"
]
wrapper_sig = original_sig.replace(parameters=filtered_params)
wrapper_param_names = [p.name for p in filtered_params]
@wraps(func)
async def wrapper(*args, **kwargs):
# Get authentication context early
tool_name = func.__name__
authenticated_user, _, mcp_session_id = await _get_auth_context(tool_name)
# Extract user_google_email based on OAuth mode
if is_oauth21_enabled():
user_google_email = _extract_oauth21_user_email(
authenticated_user, tool_name
)
else:
user_google_email = _extract_oauth20_user_email(
args, kwargs, wrapper_sig
)
# Authenticate all services
with ExitStack() as stack:
for config in service_configs:
service_type = config["service_type"]
scopes = config["scopes"]
param_name = config["param_name"]
version = config.get("version")
if service_type not in SERVICE_CONFIGS:
raise Exception(f"Unknown service type: {service_type}")
service_config = SERVICE_CONFIGS[service_type]
service_name = service_config["service"]
service_version = version or service_config["version"]
resolved_scopes = _resolve_scopes(scopes)
try:
# Detect OAuth version (simplified for multiple services)
use_oauth21 = (
is_oauth21_enabled() and authenticated_user is not None
)
# In OAuth 2.0 mode, we may need to override user_google_email
if not is_oauth21_enabled():
user_google_email, args = _override_oauth21_user_email(
use_oauth21,
authenticated_user,
user_google_email,
args,
kwargs,
wrapper_param_names,
tool_name,
service_type,
)
# Authenticate service
service, _ = await _authenticate_service(
use_oauth21,
service_name,
service_version,
tool_name,
user_google_email,
resolved_scopes,
mcp_session_id,
authenticated_user,
)
# Inject service with specified parameter name
kwargs[param_name] = service
stack.callback(service.close)
except GoogleAuthenticationError as e:
logger.error(
f"[{tool_name}] GoogleAuthenticationError for service '{service_type}' (user: {user_google_email}): {e}"
)
# Re-raise the original error without wrapping it
raise
# Call the original function with refresh error handling
try:
# In OAuth 2.1 mode, we need to add user_google_email to kwargs since it was removed from signature
if is_oauth21_enabled():
kwargs["user_google_email"] = user_google_email
return await func(*args, **kwargs)
except RefreshError as e:
# Handle token refresh errors gracefully
error_message = _handle_token_refresh_error(
e, user_google_email, "Multiple Services"
)
raise GoogleAuthenticationError(error_message)
# Set the wrapper's signature
wrapper.__signature__ = wrapper_sig
# Conditionally modify docstring to remove user_google_email parameter documentation
if is_oauth21_enabled():
logger.debug(
"OAuth 2.1 mode enabled, removing user_google_email from docstring"
)
if func.__doc__:
wrapper.__doc__ = _remove_user_email_arg_from_docstring(func.__doc__)
# Attach all required scopes to the wrapper for tool filtering
all_scopes = []
for config in service_configs:
all_scopes.extend(_resolve_scopes(config["scopes"]))
wrapper._required_google_scopes = all_scopes
return wrapper
return decorator