apply ruff formatting

This commit is contained in:
Taylor Wilsdon
2025-12-13 13:49:28 -08:00
parent 1d80a24ca4
commit 6b8352a354
50 changed files with 4010 additions and 2842 deletions

View File

@@ -1 +1 @@
# Make the auth directory a Python package
# Make the auth directory a Python package

View File

@@ -1,6 +1,7 @@
"""
Authentication middleware to populate context state with user information
"""
import jwt
import logging
import os
@@ -20,11 +21,11 @@ class AuthInfoMiddleware(Middleware):
Middleware to extract authentication information from JWT tokens
and populate the FastMCP context state for use in tools and prompts.
"""
def __init__(self):
super().__init__()
self.auth_provider_type = "GoogleProvider"
async def _process_request_for_auth(self, context: MiddlewareContext):
"""Helper to extract, verify, and store auth info from a request."""
if not context.fastmcp_context:
@@ -42,69 +43,101 @@ class AuthInfoMiddleware(Middleware):
headers = get_http_headers()
if headers:
logger.debug("Processing HTTP headers for authentication")
# Get the Authorization header
auth_header = headers.get("authorization", "")
if auth_header.startswith("Bearer "):
token_str = auth_header[7:] # Remove "Bearer " prefix
logger.debug("Found Bearer token")
# For Google OAuth tokens (ya29.*), we need to verify them differently
if token_str.startswith("ya29."):
logger.debug("Detected Google OAuth access token format")
# Verify the token to get user info
from core.server import get_auth_provider
auth_provider = get_auth_provider()
if auth_provider:
try:
# Verify the token
verified_auth = await auth_provider.verify_token(token_str)
verified_auth = await auth_provider.verify_token(
token_str
)
if verified_auth:
# Extract user info from verified token
user_email = None
if hasattr(verified_auth, 'claims'):
if hasattr(verified_auth, "claims"):
user_email = verified_auth.claims.get("email")
# Get expires_at, defaulting to 1 hour from now if not available
if hasattr(verified_auth, 'expires_at'):
if hasattr(verified_auth, "expires_at"):
expires_at = verified_auth.expires_at
else:
expires_at = int(time.time()) + 3600 # Default to 1 hour
expires_at = (
int(time.time()) + 3600
) # Default to 1 hour
# Get client_id from verified auth or use default
client_id = getattr(verified_auth, 'client_id', None) or "google"
client_id = (
getattr(verified_auth, "client_id", None)
or "google"
)
access_token = SimpleNamespace(
token=token_str,
client_id=client_id,
scopes=verified_auth.scopes if hasattr(verified_auth, 'scopes') else [],
scopes=verified_auth.scopes
if hasattr(verified_auth, "scopes")
else [],
session_id=f"google_oauth_{token_str[:8]}",
expires_at=expires_at,
# Add other fields that might be needed
sub=verified_auth.sub if hasattr(verified_auth, 'sub') else user_email,
email=user_email
sub=verified_auth.sub
if hasattr(verified_auth, "sub")
else user_email,
email=user_email,
)
# Store in context state - this is the authoritative authentication state
context.fastmcp_context.set_state("access_token", access_token)
mcp_session_id = getattr(context.fastmcp_context, "session_id", None)
context.fastmcp_context.set_state(
"access_token", access_token
)
mcp_session_id = getattr(
context.fastmcp_context, "session_id", None
)
ensure_session_from_access_token(
verified_auth,
user_email,
mcp_session_id,
)
context.fastmcp_context.set_state("access_token_obj", verified_auth)
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
context.fastmcp_context.set_state("token_type", "google_oauth")
context.fastmcp_context.set_state("user_email", user_email)
context.fastmcp_context.set_state("username", user_email)
context.fastmcp_context.set_state(
"access_token_obj", verified_auth
)
context.fastmcp_context.set_state(
"auth_provider_type", self.auth_provider_type
)
context.fastmcp_context.set_state(
"token_type", "google_oauth"
)
context.fastmcp_context.set_state(
"user_email", user_email
)
context.fastmcp_context.set_state(
"username", user_email
)
# Set the definitive authentication state
context.fastmcp_context.set_state("authenticated_user_email", user_email)
context.fastmcp_context.set_state("authenticated_via", "bearer_token")
logger.info(f"Authenticated via Google OAuth: {user_email}")
context.fastmcp_context.set_state(
"authenticated_user_email", user_email
)
context.fastmcp_context.set_state(
"authenticated_via", "bearer_token"
)
logger.info(
f"Authenticated via Google OAuth: {user_email}"
)
else:
logger.error("Failed to verify Google OAuth token")
# Don't set authenticated_user_email if verification failed
@@ -113,18 +146,29 @@ class AuthInfoMiddleware(Middleware):
# Still store the unverified token - service decorator will handle verification
access_token = SimpleNamespace(
token=token_str,
client_id=os.getenv("GOOGLE_OAUTH_CLIENT_ID", "google"),
client_id=os.getenv(
"GOOGLE_OAUTH_CLIENT_ID", "google"
),
scopes=[],
session_id=f"google_oauth_{token_str[:8]}",
expires_at=int(time.time()) + 3600, # Default to 1 hour
expires_at=int(time.time())
+ 3600, # Default to 1 hour
sub="unknown",
email=""
email="",
)
context.fastmcp_context.set_state(
"access_token", access_token
)
context.fastmcp_context.set_state(
"auth_provider_type", self.auth_provider_type
)
context.fastmcp_context.set_state(
"token_type", "google_oauth"
)
context.fastmcp_context.set_state("access_token", access_token)
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
context.fastmcp_context.set_state("token_type", "google_oauth")
else:
logger.warning("No auth provider available to verify Google token")
logger.warning(
"No auth provider available to verify Google token"
)
# Store unverified token
access_token = SimpleNamespace(
token=token_str,
@@ -133,51 +177,93 @@ class AuthInfoMiddleware(Middleware):
session_id=f"google_oauth_{token_str[:8]}",
expires_at=int(time.time()) + 3600, # Default to 1 hour
sub="unknown",
email=""
email="",
)
context.fastmcp_context.set_state("access_token", access_token)
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
context.fastmcp_context.set_state("token_type", "google_oauth")
context.fastmcp_context.set_state(
"access_token", access_token
)
context.fastmcp_context.set_state(
"auth_provider_type", self.auth_provider_type
)
context.fastmcp_context.set_state(
"token_type", "google_oauth"
)
else:
# Decode JWT to get user info
try:
token_payload = jwt.decode(
token_str,
options={"verify_signature": False}
token_str, options={"verify_signature": False}
)
logger.debug(f"JWT payload decoded: {list(token_payload.keys())}")
logger.debug(
f"JWT payload decoded: {list(token_payload.keys())}"
)
# Create an AccessToken-like object
access_token = SimpleNamespace(
token=token_str,
client_id=token_payload.get("client_id", "unknown"),
scopes=token_payload.get("scope", "").split() if token_payload.get("scope") else [],
session_id=token_payload.get("sid", token_payload.get("jti", token_payload.get("session_id", "unknown"))),
expires_at=token_payload.get("exp", 0)
scopes=token_payload.get("scope", "").split()
if token_payload.get("scope")
else [],
session_id=token_payload.get(
"sid",
token_payload.get(
"jti",
token_payload.get("session_id", "unknown"),
),
),
expires_at=token_payload.get("exp", 0),
)
# Store in context state
context.fastmcp_context.set_state("access_token", access_token)
context.fastmcp_context.set_state(
"access_token", access_token
)
# Store additional user info
context.fastmcp_context.set_state("user_id", token_payload.get("sub"))
context.fastmcp_context.set_state("username", token_payload.get("username", token_payload.get("email")))
context.fastmcp_context.set_state("name", token_payload.get("name"))
context.fastmcp_context.set_state("auth_time", token_payload.get("auth_time"))
context.fastmcp_context.set_state("issuer", token_payload.get("iss"))
context.fastmcp_context.set_state("audience", token_payload.get("aud"))
context.fastmcp_context.set_state("jti", token_payload.get("jti"))
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
context.fastmcp_context.set_state(
"user_id", token_payload.get("sub")
)
context.fastmcp_context.set_state(
"username",
token_payload.get(
"username", token_payload.get("email")
),
)
context.fastmcp_context.set_state(
"name", token_payload.get("name")
)
context.fastmcp_context.set_state(
"auth_time", token_payload.get("auth_time")
)
context.fastmcp_context.set_state(
"issuer", token_payload.get("iss")
)
context.fastmcp_context.set_state(
"audience", token_payload.get("aud")
)
context.fastmcp_context.set_state(
"jti", token_payload.get("jti")
)
context.fastmcp_context.set_state(
"auth_provider_type", self.auth_provider_type
)
# Set the definitive authentication state for JWT tokens
user_email = token_payload.get("email", token_payload.get("username"))
user_email = token_payload.get(
"email", token_payload.get("username")
)
if user_email:
context.fastmcp_context.set_state("authenticated_user_email", user_email)
context.fastmcp_context.set_state("authenticated_via", "jwt_token")
context.fastmcp_context.set_state(
"authenticated_user_email", user_email
)
context.fastmcp_context.set_state(
"authenticated_via", "jwt_token"
)
logger.debug("JWT token processed successfully")
except jwt.DecodeError as e:
logger.error(f"Failed to decode JWT: {e}")
except Exception as e:
@@ -185,44 +271,58 @@ class AuthInfoMiddleware(Middleware):
else:
logger.debug("No Bearer token in Authorization header")
else:
logger.debug("No HTTP headers available (might be using stdio transport)")
logger.debug(
"No HTTP headers available (might be using stdio transport)"
)
except Exception as e:
logger.debug(f"Could not get HTTP request: {e}")
# After trying HTTP headers, check for other authentication methods
# This consolidates all authentication logic in the middleware
if not context.fastmcp_context.get_state("authenticated_user_email"):
logger.debug("No authentication found via bearer token, checking other methods")
logger.debug(
"No authentication found via bearer token, checking other methods"
)
# Check transport mode
from core.config import get_transport_mode
transport_mode = get_transport_mode()
if transport_mode == "stdio":
# In stdio mode, check if there's a session with credentials
# This is ONLY safe in stdio mode because it's single-user
logger.debug("Checking for stdio mode authentication")
# Get the requested user from the context if available
requested_user = None
if hasattr(context, 'request') and hasattr(context.request, 'params'):
requested_user = context.request.params.get('user_google_email')
elif hasattr(context, 'arguments'):
if hasattr(context, "request") and hasattr(context.request, "params"):
requested_user = context.request.params.get("user_google_email")
elif hasattr(context, "arguments"):
# FastMCP may store arguments differently
requested_user = context.arguments.get('user_google_email')
requested_user = context.arguments.get("user_google_email")
if requested_user:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
# Check if user has a recent session
if store.has_session(requested_user):
logger.debug(f"Using recent stdio session for {requested_user}")
logger.debug(
f"Using recent stdio session for {requested_user}"
)
# In stdio mode, we can trust the user has authenticated recently
context.fastmcp_context.set_state("authenticated_user_email", requested_user)
context.fastmcp_context.set_state("authenticated_via", "stdio_session")
context.fastmcp_context.set_state("auth_provider_type", "oauth21_stdio")
context.fastmcp_context.set_state(
"authenticated_user_email", requested_user
)
context.fastmcp_context.set_state(
"authenticated_via", "stdio_session"
)
context.fastmcp_context.set_state(
"auth_provider_type", "oauth21_stdio"
)
except Exception as e:
logger.debug(f"Error checking stdio session: {e}")
@@ -230,73 +330,95 @@ class AuthInfoMiddleware(Middleware):
if not context.fastmcp_context.get_state("authenticated_user_email"):
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
single_user = store.get_single_user_email()
if single_user:
logger.debug(
f"Defaulting to single stdio OAuth session for {single_user}"
)
context.fastmcp_context.set_state("authenticated_user_email", single_user)
context.fastmcp_context.set_state("authenticated_via", "stdio_single_session")
context.fastmcp_context.set_state("auth_provider_type", "oauth21_stdio")
context.fastmcp_context.set_state(
"authenticated_user_email", single_user
)
context.fastmcp_context.set_state(
"authenticated_via", "stdio_single_session"
)
context.fastmcp_context.set_state(
"auth_provider_type", "oauth21_stdio"
)
context.fastmcp_context.set_state("user_email", single_user)
context.fastmcp_context.set_state("username", single_user)
except Exception as e:
logger.debug(f"Error determining stdio single-user session: {e}")
logger.debug(
f"Error determining stdio single-user session: {e}"
)
# Check for MCP session binding
if not context.fastmcp_context.get_state("authenticated_user_email") and hasattr(context.fastmcp_context, 'session_id'):
if not context.fastmcp_context.get_state(
"authenticated_user_email"
) and hasattr(context.fastmcp_context, "session_id"):
mcp_session_id = context.fastmcp_context.session_id
if mcp_session_id:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
# Check if this MCP session is bound to a user
bound_user = store.get_user_by_mcp_session(mcp_session_id)
if bound_user:
logger.debug(f"MCP session bound to {bound_user}")
context.fastmcp_context.set_state("authenticated_user_email", bound_user)
context.fastmcp_context.set_state("authenticated_via", "mcp_session_binding")
context.fastmcp_context.set_state("auth_provider_type", "oauth21_session")
context.fastmcp_context.set_state(
"authenticated_user_email", bound_user
)
context.fastmcp_context.set_state(
"authenticated_via", "mcp_session_binding"
)
context.fastmcp_context.set_state(
"auth_provider_type", "oauth21_session"
)
except Exception as e:
logger.debug(f"Error checking MCP session binding: {e}")
async def on_call_tool(self, context: MiddlewareContext, call_next):
"""Extract auth info from token and set in context state"""
logger.debug("Processing tool call authentication")
try:
await self._process_request_for_auth(context)
logger.debug("Passing to next handler")
result = await call_next(context)
logger.debug("Handler completed")
return result
except Exception as e:
# Check if this is an authentication error - don't log traceback for these
if "GoogleAuthenticationError" in str(type(e)) or "Access denied: Cannot retrieve credentials" in str(e):
if "GoogleAuthenticationError" in str(
type(e)
) or "Access denied: Cannot retrieve credentials" in str(e):
logger.info(f"Authentication check failed: {e}")
else:
logger.error(f"Error in on_call_tool middleware: {e}", exc_info=True)
raise
async def on_get_prompt(self, context: MiddlewareContext, call_next):
"""Extract auth info for prompt requests too"""
logger.debug("Processing prompt authentication")
try:
await self._process_request_for_auth(context)
logger.debug("Passing prompt to next handler")
result = await call_next(context)
logger.debug("Prompt handler completed")
return result
except Exception as e:
# Check if this is an authentication error - don't log traceback for these
if "GoogleAuthenticationError" in str(type(e)) or "Access denied: Cannot retrieve credentials" in str(e):
if "GoogleAuthenticationError" in str(
type(e)
) or "Access denied: Cannot retrieve credentials" in str(e):
logger.info(f"Authentication check failed in prompt: {e}")
else:
logger.error(f"Error in on_get_prompt middleware: {e}", exc_info=True)

View File

@@ -4,6 +4,7 @@ External OAuth Provider for Google Workspace MCP
Extends FastMCP's GoogleProvider to support external OAuth flows where
access tokens (ya29.*) are issued by external systems and need validation.
"""
import logging
import time
from typing import Optional
@@ -55,7 +56,7 @@ class ExternalOAuthProvider(GoogleProvider):
token=token,
token_uri="https://oauth2.googleapis.com/token",
client_id=self._client_id,
client_secret=self._client_secret
client_secret=self._client_secret,
)
# Validate token by calling userinfo API
@@ -63,20 +64,27 @@ class ExternalOAuthProvider(GoogleProvider):
if user_info and user_info.get("email"):
# Token is valid - create AccessToken object
logger.info(f"Validated external access token for: {user_info['email']}")
logger.info(
f"Validated external access token for: {user_info['email']}"
)
# Create a mock AccessToken that the middleware expects
# This matches the structure that FastMCP's AccessToken would have
from types import SimpleNamespace
scope_list = list(getattr(self, "required_scopes", []) or [])
access_token = SimpleNamespace(
token=token,
scopes=scope_list,
expires_at=int(time.time()) + 3600, # Default to 1-hour validity
claims={"email": user_info["email"], "sub": user_info.get("id")},
expires_at=int(time.time())
+ 3600, # Default to 1-hour validity
claims={
"email": user_info["email"],
"sub": user_info.get("id"),
},
client_id=self._client_id,
email=user_info["email"],
sub=user_info.get("id")
sub=user_info.get("id"),
)
return access_token
else:

View File

@@ -15,7 +15,7 @@ from google.auth.transport.requests import Request
from google.auth.exceptions import RefreshError
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from auth.scopes import SCOPES, get_current_scopes # noqa
from auth.scopes import SCOPES, get_current_scopes # noqa
from auth.oauth21_session_store import get_oauth21_session_store
from auth.credential_store import get_credential_store
from auth.oauth_config import get_oauth_config, is_stateless_mode
@@ -136,11 +136,15 @@ def save_credentials_to_session(session_id: str, credentials: Credentials):
client_secret=credentials.client_secret,
scopes=credentials.scopes,
expiry=credentials.expiry,
mcp_session_id=session_id
mcp_session_id=session_id,
)
logger.debug(
f"Credentials saved to OAuth21SessionStore for session_id: {session_id}, user: {user_email}"
)
logger.debug(f"Credentials saved to OAuth21SessionStore for session_id: {session_id}, user: {user_email}")
else:
logger.warning(f"Could not save credentials to session store - no user email found for session: {session_id}")
logger.warning(
f"Could not save credentials to session store - no user email found for session: {session_id}"
)
def load_credentials_from_session(session_id: str) -> Optional[Credentials]:
@@ -359,7 +363,9 @@ async def start_auth_flow(
try:
session_id = get_fastmcp_session_id()
except Exception as e:
logger.debug(f"Could not retrieve FastMCP session ID for state binding: {e}")
logger.debug(
f"Could not retrieve FastMCP session ID for state binding: {e}"
)
store = get_oauth21_session_store()
store.store_oauth_state(oauth_state, session_id=session_id)
@@ -460,16 +466,16 @@ def handle_auth_callback(
state_values = parse_qs(parsed_response.query).get("state")
state = state_values[0] if state_values else None
state_info = store.validate_and_consume_oauth_state(state, session_id=session_id)
state_info = store.validate_and_consume_oauth_state(
state, session_id=session_id
)
logger.debug(
"Validated OAuth callback state %s for session %s",
(state[:8] if state else "<missing>"),
state_info.get("session_id") or "<unknown>",
)
flow = create_oauth_flow(
scopes=scopes, redirect_uri=redirect_uri, state=state
)
flow = create_oauth_flow(scopes=scopes, redirect_uri=redirect_uri, state=state)
# Exchange the authorization code for credentials
# Note: fetch_token will use the redirect_uri configured in the flow
@@ -502,7 +508,7 @@ def handle_auth_callback(
scopes=credentials.scopes,
expiry=credentials.expiry,
mcp_session_id=session_id,
issuer="https://accounts.google.com" # Add issuer for Google tokens
issuer="https://accounts.google.com", # Add issuer for Google tokens
)
# If session_id is provided, also save to session cache for compatibility
@@ -546,7 +552,9 @@ def get_credentials(
# Try to get credentials by MCP session
credentials = store.get_credentials_by_mcp_session(session_id)
if credentials:
logger.info(f"[get_credentials] Found OAuth 2.1 credentials for MCP session {session_id}")
logger.info(
f"[get_credentials] Found OAuth 2.1 credentials for MCP session {session_id}"
)
# Check scopes
if not all(scope in credentials.scopes for scope in required_scopes):
@@ -562,7 +570,9 @@ def get_credentials(
# Try to refresh
try:
credentials.refresh(Request())
logger.info(f"[get_credentials] Refreshed OAuth 2.1 credentials for session {session_id}")
logger.info(
f"[get_credentials] Refreshed OAuth 2.1 credentials for session {session_id}"
)
# Update stored credentials
user_email = store.get_user_by_mcp_session(session_id)
if user_email:
@@ -572,11 +582,13 @@ def get_credentials(
refresh_token=credentials.refresh_token,
scopes=credentials.scopes,
expiry=credentials.expiry,
mcp_session_id=session_id
mcp_session_id=session_id,
)
return credentials
except Exception as e:
logger.error(f"[get_credentials] Failed to refresh OAuth 2.1 credentials: {e}")
logger.error(
f"[get_credentials] Failed to refresh OAuth 2.1 credentials: {e}"
)
return None
except ImportError:
pass # OAuth 2.1 store not available
@@ -692,7 +704,9 @@ def get_credentials(
credential_store = get_credential_store()
credential_store.store_credential(user_google_email, credentials)
else:
logger.info(f"Skipping credential file save in stateless mode for {user_google_email}")
logger.info(
f"Skipping credential file save in stateless mode for {user_google_email}"
)
# Also update OAuth21SessionStore
store = get_oauth21_session_store()
@@ -706,7 +720,7 @@ def get_credentials(
scopes=credentials.scopes,
expiry=credentials.expiry,
mcp_session_id=session_id,
issuer="https://accounts.google.com" # Add issuer for Google tokens
issuer="https://accounts.google.com", # Add issuer for Google tokens
)
if session_id: # Update session cache if it was the source or is active
@@ -795,9 +809,13 @@ async def get_authenticated_google_service(
# First try context variable (works in async context)
session_id = get_fastmcp_session_id()
if session_id:
logger.debug(f"[{tool_name}] Got FastMCP session ID from context: {session_id}")
logger.debug(
f"[{tool_name}] Got FastMCP session ID from context: {session_id}"
)
else:
logger.debug(f"[{tool_name}] Context variable returned None/empty session ID")
logger.debug(
f"[{tool_name}] Context variable returned None/empty session ID"
)
except Exception as e:
logger.debug(
f"[{tool_name}] Could not get FastMCP session from context: {e}"
@@ -807,17 +825,25 @@ async def get_authenticated_google_service(
if not session_id and get_fastmcp_context:
try:
fastmcp_ctx = get_fastmcp_context()
if fastmcp_ctx and hasattr(fastmcp_ctx, 'session_id'):
if fastmcp_ctx and hasattr(fastmcp_ctx, "session_id"):
session_id = fastmcp_ctx.session_id
logger.debug(f"[{tool_name}] Got FastMCP session ID directly: {session_id}")
logger.debug(
f"[{tool_name}] Got FastMCP session ID directly: {session_id}"
)
else:
logger.debug(f"[{tool_name}] FastMCP context exists but no session_id attribute")
logger.debug(
f"[{tool_name}] FastMCP context exists but no session_id attribute"
)
except Exception as e:
logger.debug(f"[{tool_name}] Could not get FastMCP context directly: {e}")
logger.debug(
f"[{tool_name}] Could not get FastMCP context directly: {e}"
)
# Final fallback: log if we still don't have session_id
if not session_id:
logger.warning(f"[{tool_name}] Unable to obtain FastMCP session ID from any source")
logger.warning(
f"[{tool_name}] Unable to obtain FastMCP session ID from any source"
)
logger.info(
f"[{tool_name}] Attempting to get authenticated {service_name} service. Email: '{user_google_email}', Session: '{session_id}'"
@@ -838,8 +864,12 @@ async def get_authenticated_google_service(
)
if not credentials or not credentials.valid:
logger.warning(f"[{tool_name}] No valid credentials. Email: '{user_google_email}'.")
logger.info(f"[{tool_name}] Valid email '{user_google_email}' provided, initiating auth flow.")
logger.warning(
f"[{tool_name}] No valid credentials. Email: '{user_google_email}'."
)
logger.info(
f"[{tool_name}] Valid email '{user_google_email}' provided, initiating auth flow."
)
# Ensure OAuth callback is available
from auth.oauth_callback_server import ensure_oauth_callback_available

View File

@@ -26,24 +26,26 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
Middleware that extracts session information from requests and makes it
available to MCP tool functions via context variables.
"""
async def dispatch(self, request: Request, call_next: Callable) -> Any:
"""Process request and set session context."""
logger.debug(f"MCPSessionMiddleware processing request: {request.method} {request.url.path}")
logger.debug(
f"MCPSessionMiddleware processing request: {request.method} {request.url.path}"
)
# Skip non-MCP paths
if not request.url.path.startswith("/mcp"):
logger.debug(f"Skipping non-MCP path: {request.url.path}")
return await call_next(request)
session_context = None
try:
# Extract session information
headers = dict(request.headers)
session_id = extract_session_from_headers(headers)
# Try to get OAuth 2.1 auth context from FastMCP
auth_context = None
user_email = None
@@ -52,28 +54,33 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
if hasattr(request.state, "auth"):
auth_context = request.state.auth
# Extract user email from auth claims if available
if hasattr(auth_context, 'claims') and auth_context.claims:
user_email = auth_context.claims.get('email')
if hasattr(auth_context, "claims") and auth_context.claims:
user_email = auth_context.claims.get("email")
# Check for FastMCP session ID (from streamable HTTP transport)
if hasattr(request.state, "session_id"):
mcp_session_id = request.state.session_id
logger.debug(f"Found FastMCP session ID: {mcp_session_id}")
# Also check Authorization header for bearer tokens
auth_header = headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer ") and not user_email:
if (
auth_header
and auth_header.lower().startswith("bearer ")
and not user_email
):
try:
import jwt
token = auth_header[7:] # Remove "Bearer " prefix
# Decode without verification to extract email
claims = jwt.decode(token, options={"verify_signature": False})
user_email = claims.get('email')
user_email = claims.get("email")
if user_email:
logger.debug(f"Extracted user email from JWT: {user_email}")
except Exception:
pass
# Build session context
if session_id or auth_context or user_email or mcp_session_id:
# Create session ID hierarchy: explicit session_id > Google user session > FastMCP session
@@ -82,10 +89,11 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
effective_session_id = f"google_{user_email}"
elif not effective_session_id and mcp_session_id:
effective_session_id = mcp_session_id
session_context = SessionContext(
session_id=effective_session_id,
user_id=user_email or (auth_context.user_id if auth_context else None),
user_id=user_email
or (auth_context.user_id if auth_context else None),
auth_context=auth_context,
request=request,
metadata={
@@ -93,19 +101,19 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
"method": request.method,
"user_email": user_email,
"mcp_session_id": mcp_session_id,
}
},
)
logger.debug(
f"MCP request with session: session_id={session_context.session_id}, "
f"user_id={session_context.user_id}, path={request.url.path}"
)
# Process request with session context
with SessionContextManager(session_context):
response = await call_next(request)
return response
except Exception as e:
logger.error(f"Error in MCP session middleware: {e}")
# Continue without session context

View File

@@ -34,7 +34,9 @@ def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
try:
return expiry.astimezone(timezone.utc).replace(tzinfo=None)
except Exception: # pragma: no cover - defensive
logger.debug("Failed to normalize aware expiry; returning without tzinfo")
logger.debug(
"Failed to normalize aware expiry; returning without tzinfo"
)
return expiry.replace(tzinfo=None)
return expiry # Already naive; assumed to represent UTC
@@ -51,15 +53,15 @@ def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
# Context variable to store the current session information
_current_session_context: contextvars.ContextVar[Optional['SessionContext']] = contextvars.ContextVar(
'current_session_context',
default=None
_current_session_context: contextvars.ContextVar[Optional["SessionContext"]] = (
contextvars.ContextVar("current_session_context", default=None)
)
@dataclass
class SessionContext:
"""Container for session-related information."""
session_id: Optional[str] = None
user_id: Optional[str] = None
auth_context: Optional[Any] = None
@@ -81,7 +83,9 @@ def set_session_context(context: Optional[SessionContext]):
"""
_current_session_context.set(context)
if context:
logger.debug(f"Set session context: session_id={context.session_id}, user_id={context.user_id}")
logger.debug(
f"Set session context: session_id={context.session_id}, user_id={context.user_id}"
)
else:
logger.debug("Cleared session context")
@@ -161,6 +165,7 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
# If no session found, create a temporary session ID from token hash
# This allows header-based authentication to work with session context
import hashlib
token_hash = hashlib.sha256(token.encode()).hexdigest()[:8]
return f"bearer_token_{token_hash}"
@@ -171,6 +176,7 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
# OAuth21SessionStore - Main Session Management
# =============================================================================
class OAuth21SessionStore:
"""
Global store for OAuth 2.1 authenticated sessions.
@@ -185,8 +191,12 @@ class OAuth21SessionStore:
def __init__(self):
self._sessions: Dict[str, Dict[str, Any]] = {}
self._mcp_session_mapping: Dict[str, str] = {} # Maps FastMCP session ID -> user email
self._session_auth_binding: Dict[str, str] = {} # Maps session ID -> authenticated user email (immutable)
self._mcp_session_mapping: Dict[
str, str
] = {} # Maps FastMCP session ID -> user email
self._session_auth_binding: Dict[
str, str
] = {} # Maps session ID -> authenticated user email (immutable)
self._oauth_states: Dict[str, Dict[str, Any]] = {}
self._lock = RLock()
@@ -258,7 +268,9 @@ class OAuth21SessionStore:
state_info = self._oauth_states.get(state)
if not state_info:
logger.error("SECURITY: OAuth callback received unknown or expired state")
logger.error(
"SECURITY: OAuth callback received unknown or expired state"
)
raise ValueError("Invalid or expired OAuth state parameter")
bound_session = state_info.get("session_id")
@@ -332,16 +344,26 @@ class OAuth21SessionStore:
# Create immutable session binding (first binding wins, cannot be changed)
if mcp_session_id not in self._session_auth_binding:
self._session_auth_binding[mcp_session_id] = user_email
logger.info(f"Created immutable session binding: {mcp_session_id} -> {user_email}")
logger.info(
f"Created immutable session binding: {mcp_session_id} -> {user_email}"
)
elif self._session_auth_binding[mcp_session_id] != user_email:
# Security: Attempt to bind session to different user
logger.error(f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}")
raise ValueError(f"Session {mcp_session_id} is already bound to a different user")
logger.error(
f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}"
)
raise ValueError(
f"Session {mcp_session_id} is already bound to a different user"
)
self._mcp_session_mapping[mcp_session_id] = user_email
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})")
logger.info(
f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})"
)
else:
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})")
logger.info(
f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})"
)
# Also create binding for the OAuth session ID
if session_id and session_id not in self._session_auth_binding:
@@ -382,7 +404,9 @@ class OAuth21SessionStore:
logger.error(f"Failed to create credentials for {user_email}: {e}")
return None
def get_credentials_by_mcp_session(self, mcp_session_id: str) -> Optional[Credentials]:
def get_credentials_by_mcp_session(
self, mcp_session_id: str
) -> Optional[Credentials]:
"""
Get Google credentials using FastMCP session ID.
@@ -407,7 +431,7 @@ class OAuth21SessionStore:
requested_user_email: str,
session_id: Optional[str] = None,
auth_token_email: Optional[str] = None,
allow_recent_auth: bool = False
allow_recent_auth: bool = False,
) -> Optional[Credentials]:
"""
Get Google credentials with session validation.
@@ -466,6 +490,7 @@ class OAuth21SessionStore:
# Check transport mode to ensure this is only used in stdio
try:
from core.config import get_transport_mode
transport_mode = get_transport_mode()
if transport_mode != "stdio":
logger.error(
@@ -533,7 +558,9 @@ class OAuth21SessionStore:
# Also remove from auth binding
if mcp_session_id in self._session_auth_binding:
del self._session_auth_binding[mcp_session_id]
logger.info(f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}")
logger.info(
f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}"
)
# Remove OAuth session binding if exists
if session_id and session_id in self._session_auth_binding:
@@ -612,7 +639,9 @@ def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
try:
client_secret = secret_obj.get_secret_value() # type: ignore[call-arg]
except Exception as exc: # pragma: no cover - defensive
logger.debug(f"Failed to resolve client secret from provider: {exc}")
logger.debug(
f"Failed to resolve client secret from provider: {exc}"
)
elif isinstance(secret_obj, str):
client_secret = secret_obj
@@ -629,7 +658,9 @@ def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
return client_id, client_secret
def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Credentials]:
def _build_credentials_from_provider(
access_token: AccessToken,
) -> Optional[Credentials]:
"""Construct Google credentials from the provider cache."""
if not _auth_provider:
return None
@@ -640,10 +671,14 @@ def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Cred
client_id, client_secret = _resolve_client_credentials()
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(access_token.token)
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(
access_token.token
)
refresh_token_obj = None
if refresh_token_value:
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(refresh_token_value)
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(
refresh_token_value
)
expiry = None
expires_at = getattr(access_entry, "expires_at", None)
@@ -730,7 +765,9 @@ def ensure_session_from_access_token(
return credentials
def get_credentials_from_token(access_token: str, user_email: Optional[str] = None) -> Optional[Credentials]:
def get_credentials_from_token(
access_token: str, user_email: Optional[str] = None
) -> Optional[Credentials]:
"""
Convert a bearer token to Google credentials.
@@ -753,14 +790,18 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
# If the FastMCP provider is managing tokens, sync from provider storage
if _auth_provider:
access_record = getattr(_auth_provider, "_access_tokens", {}).get(access_token)
access_record = getattr(_auth_provider, "_access_tokens", {}).get(
access_token
)
if access_record:
logger.debug("Building credentials from FastMCP provider cache")
return ensure_session_from_access_token(access_record, user_email)
# Otherwise, create minimal credentials with just the access token
# Assume token is valid for 1 hour (typical for Google tokens)
expiry = _normalize_expiry_to_naive_utc(datetime.now(timezone.utc) + timedelta(hours=1))
expiry = _normalize_expiry_to_naive_utc(
datetime.now(timezone.utc) + timedelta(hours=1)
)
client_id, client_secret = _resolve_client_credentials()
credentials = Credentials(
@@ -770,7 +811,7 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
client_id=client_id,
client_secret=client_secret,
scopes=None,
expiry=expiry
expiry=expiry,
)
logger.debug("Created fallback Google credentials from bearer token")
@@ -781,7 +822,9 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
return None
def store_token_session(token_response: dict, user_email: str, mcp_session_id: Optional[str] = None) -> str:
def store_token_session(
token_response: dict, user_email: str, mcp_session_id: Optional[str] = None
) -> str:
"""
Store a token response in the session store.
@@ -802,9 +845,12 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
if not mcp_session_id:
try:
from core.context import get_fastmcp_session_id
mcp_session_id = get_fastmcp_session_id()
if mcp_session_id:
logger.debug(f"Got FastMCP session ID from context: {mcp_session_id}")
logger.debug(
f"Got FastMCP session ID from context: {mcp_session_id}"
)
except Exception as e:
logger.debug(f"Could not get FastMCP session from context: {e}")
@@ -815,7 +861,9 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
client_id, client_secret = _resolve_client_credentials()
scopes = token_response.get("scope", "")
scopes_list = scopes.split() if scopes else None
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600))
expiry = datetime.now(timezone.utc) + timedelta(
seconds=token_response.get("expires_in", 3600)
)
store.store_session(
user_email=user_email,
@@ -832,7 +880,9 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
)
if mcp_session_id:
logger.info(f"Stored token session for {user_email} with MCP session {mcp_session_id}")
logger.info(
f"Stored token session for {user_email} with MCP session {mcp_session_id}"
)
else:
logger.info(f"Stored token session for {user_email}")

View File

@@ -17,13 +17,18 @@ from fastapi.responses import FileResponse, JSONResponse
from typing import Optional
from urllib.parse import urlparse
from auth.scopes import SCOPES, get_current_scopes # noqa
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
from auth.scopes import SCOPES, get_current_scopes # noqa
from auth.oauth_responses import (
create_error_response,
create_success_response,
create_server_error_response,
)
from auth.google_auth import handle_auth_callback, check_client_secrets
from auth.oauth_config import get_oauth_redirect_uri
logger = logging.getLogger(__name__)
class MinimalOAuthServer:
"""
Minimal HTTP server for OAuth callbacks in stdio mode.
@@ -59,7 +64,9 @@ class MinimalOAuthServer:
return create_error_response(error_message)
if not code:
error_message = "Authentication failed: No authorization code received from Google."
error_message = (
"Authentication failed: No authorization code received from Google."
)
logger.error(error_message)
return create_error_response(error_message)
@@ -69,7 +76,9 @@ class MinimalOAuthServer:
if error_message:
return create_server_error_response(error_message)
logger.info(f"OAuth callback: Received code (state: {state}). Attempting to exchange for tokens.")
logger.info(
f"OAuth callback: Received code (state: {state}). Attempting to exchange for tokens."
)
# Session ID tracking removed - not needed
@@ -79,16 +88,20 @@ class MinimalOAuthServer:
scopes=get_current_scopes(),
authorization_response=str(request.url),
redirect_uri=redirect_uri,
session_id=None
session_id=None,
)
logger.info(f"OAuth callback: Successfully authenticated user: {verified_user_id} (state: {state}).")
logger.info(
f"OAuth callback: Successfully authenticated user: {verified_user_id} (state: {state})."
)
# Return success page using shared template
return create_success_response(verified_user_id)
except Exception as e:
error_message_detail = f"Error processing OAuth callback (state: {state}): {str(e)}"
error_message_detail = (
f"Error processing OAuth callback (state: {state}): {str(e)}"
)
logger.error(error_message_detail, exc_info=True)
return create_server_error_response(str(e))
@@ -101,24 +114,22 @@ class MinimalOAuthServer:
"""Serve a stored attachment file."""
storage = get_attachment_storage()
metadata = storage.get_attachment_metadata(file_id)
if not metadata:
return JSONResponse(
{"error": "Attachment not found or expired"},
status_code=404
{"error": "Attachment not found or expired"}, status_code=404
)
file_path = storage.get_attachment_path(file_id)
if not file_path:
return JSONResponse(
{"error": "Attachment file not found"},
status_code=404
{"error": "Attachment file not found"}, status_code=404
)
return FileResponse(
path=str(file_path),
filename=metadata["filename"],
media_type=metadata["mime_type"]
media_type=metadata["mime_type"],
)
def start(self) -> tuple[bool, str]:
@@ -136,9 +147,9 @@ class MinimalOAuthServer:
# Extract hostname from base_uri (e.g., "http://localhost" -> "localhost")
try:
parsed_uri = urlparse(self.base_uri)
hostname = parsed_uri.hostname or 'localhost'
hostname = parsed_uri.hostname or "localhost"
except Exception:
hostname = 'localhost'
hostname = "localhost"
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -156,7 +167,7 @@ class MinimalOAuthServer:
host=hostname,
port=self.port,
log_level="warning",
access_log=False
access_log=False,
)
self.server = uvicorn.Server(config)
asyncio.run(self.server.serve())
@@ -178,7 +189,9 @@ class MinimalOAuthServer:
result = s.connect_ex((hostname, self.port))
if result == 0:
self.is_running = True
logger.info(f"Minimal OAuth server started on {hostname}:{self.port}")
logger.info(
f"Minimal OAuth server started on {hostname}:{self.port}"
)
return True, ""
except Exception:
pass
@@ -195,7 +208,7 @@ class MinimalOAuthServer:
try:
if self.server:
if hasattr(self.server, 'should_exit'):
if hasattr(self.server, "should_exit"):
self.server.should_exit = True
if self.server_thread and self.server_thread.is_alive():
@@ -211,7 +224,10 @@ class MinimalOAuthServer:
# Global instance for stdio mode
_minimal_oauth_server: Optional[MinimalOAuthServer] = None
def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8000, base_uri: str = "http://localhost") -> tuple[bool, str]:
def ensure_oauth_callback_available(
transport_mode: str = "stdio", port: int = 8000, base_uri: str = "http://localhost"
) -> tuple[bool, str]:
"""
Ensure OAuth callback endpoint is available for the given transport mode.
@@ -230,7 +246,9 @@ def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8
if transport_mode == "streamable-http":
# In streamable-http mode, the main FastAPI server should handle callbacks
logger.debug("Using existing FastAPI server for OAuth callbacks (streamable-http mode)")
logger.debug(
"Using existing FastAPI server for OAuth callbacks (streamable-http mode)"
)
return True, ""
elif transport_mode == "stdio":
@@ -243,10 +261,14 @@ def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8
logger.info("Starting minimal OAuth server for stdio mode")
success, error_msg = _minimal_oauth_server.start()
if success:
logger.info(f"Minimal OAuth server successfully started on {base_uri}:{port}")
logger.info(
f"Minimal OAuth server successfully started on {base_uri}:{port}"
)
return True, ""
else:
logger.error(f"Failed to start minimal OAuth server on {base_uri}:{port}: {error_msg}")
logger.error(
f"Failed to start minimal OAuth server on {base_uri}:{port}: {error_msg}"
)
return False, error_msg
else:
logger.info("Minimal OAuth server is already running")
@@ -257,6 +279,7 @@ def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8
logger.error(error_msg)
return False, error_msg
def cleanup_oauth_callback_server():
"""Clean up the minimal OAuth server if it was started."""
global _minimal_oauth_server

View File

@@ -36,19 +36,31 @@ class OAuthConfig:
self.client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
# OAuth 2.1 configuration
self.oauth21_enabled = os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true"
self.oauth21_enabled = (
os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true"
)
self.pkce_required = self.oauth21_enabled # PKCE is mandatory in OAuth 2.1
self.supported_code_challenge_methods = ["S256", "plain"] if not self.oauth21_enabled else ["S256"]
self.supported_code_challenge_methods = (
["S256", "plain"] if not self.oauth21_enabled else ["S256"]
)
# External OAuth 2.1 provider configuration
self.external_oauth21_provider = os.getenv("EXTERNAL_OAUTH21_PROVIDER", "false").lower() == "true"
self.external_oauth21_provider = (
os.getenv("EXTERNAL_OAUTH21_PROVIDER", "false").lower() == "true"
)
if self.external_oauth21_provider and not self.oauth21_enabled:
raise ValueError("EXTERNAL_OAUTH21_PROVIDER requires MCP_ENABLE_OAUTH21=true")
raise ValueError(
"EXTERNAL_OAUTH21_PROVIDER requires MCP_ENABLE_OAUTH21=true"
)
# Stateless mode configuration
self.stateless_mode = os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
self.stateless_mode = (
os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
)
if self.stateless_mode and not self.oauth21_enabled:
raise ValueError("WORKSPACE_MCP_STATELESS_MODE requires MCP_ENABLE_OAUTH21=true")
raise ValueError(
"WORKSPACE_MCP_STATELESS_MODE requires MCP_ENABLE_OAUTH21=true"
)
# Transport mode (will be set at runtime)
self._transport_mode = "stdio" # Default
@@ -95,7 +107,12 @@ class OAuthConfig:
# Don't set FASTMCP_SERVER_AUTH if using external OAuth provider
# (external OAuth means protocol-level auth is disabled, only tool-level auth)
if not self.external_oauth21_provider:
_set_if_absent("FASTMCP_SERVER_AUTH", "fastmcp.server.auth.providers.google.GoogleProvider" if self.oauth21_enabled else None)
_set_if_absent(
"FASTMCP_SERVER_AUTH",
"fastmcp.server.auth.providers.google.GoogleProvider"
if self.oauth21_enabled
else None,
)
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_ID", self.client_id)
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_SECRET", self.client_secret)
@@ -135,11 +152,13 @@ class OAuthConfig:
origins.append(self.base_url)
# VS Code and development origins
origins.extend([
"vscode-webview://",
"https://vscode.dev",
"https://github.dev",
])
origins.extend(
[
"vscode-webview://",
"https://vscode.dev",
"https://github.dev",
]
)
# Custom origins from environment
custom_origins = os.getenv("OAUTH_ALLOWED_ORIGINS")
@@ -266,6 +285,7 @@ class OAuthConfig:
# Use the structured type for cleaner detection logic
from auth.oauth_types import OAuthVersionDetectionParams
params = OAuthVersionDetectionParams.from_request(request_params)
# Clear OAuth 2.1 indicator: PKCE is present
@@ -278,6 +298,7 @@ class OAuthConfig:
if authenticated_user:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
if store.has_session(authenticated_user):
return "oauth21"
@@ -291,7 +312,9 @@ class OAuthConfig:
# Default to OAuth 2.0 for maximum compatibility
return "oauth20"
def get_authorization_server_metadata(self, scopes: Optional[List[str]] = None) -> Dict[str, Any]:
def get_authorization_server_metadata(
self, scopes: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
Get OAuth authorization server metadata per RFC 8414.
@@ -311,7 +334,10 @@ class OAuthConfig:
"userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
"response_types_supported": ["code", "token"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
"token_endpoint_auth_methods_supported": [
"client_secret_post",
"client_secret_basic",
],
"code_challenge_methods_supported": self.supported_code_challenge_methods,
}

View File

@@ -220,4 +220,4 @@ def create_server_error_response(error_detail: str) -> HTMLResponse:
</body>
</html>
"""
return HTMLResponse(content=content, status_code=500)
return HTMLResponse(content=content, status_code=500)

View File

@@ -13,10 +13,11 @@ from typing import Optional, List, Dict, Any
class OAuth21ServiceRequest:
"""
Encapsulates parameters for OAuth 2.1 service authentication requests.
This parameter object pattern reduces function complexity and makes
it easier to extend authentication parameters in the future.
"""
service_name: str
version: str
tool_name: str
@@ -26,7 +27,7 @@ class OAuth21ServiceRequest:
auth_token_email: Optional[str] = None
allow_recent_auth: bool = False
context: Optional[Dict[str, Any]] = None
def to_legacy_params(self) -> dict:
"""Convert to legacy parameter format for backward compatibility."""
return {
@@ -42,10 +43,11 @@ class OAuth21ServiceRequest:
class OAuthVersionDetectionParams:
"""
Parameters used for OAuth version detection.
Encapsulates the various signals we use to determine
whether a client supports OAuth 2.1 or needs OAuth 2.0.
"""
client_id: Optional[str] = None
client_secret: Optional[str] = None
code_challenge: Optional[str] = None
@@ -53,9 +55,11 @@ class OAuthVersionDetectionParams:
code_verifier: Optional[str] = None
authenticated_user: Optional[str] = None
session_id: Optional[str] = None
@classmethod
def from_request(cls, request_params: Dict[str, Any]) -> "OAuthVersionDetectionParams":
def from_request(
cls, request_params: Dict[str, Any]
) -> "OAuthVersionDetectionParams":
"""Create from raw request parameters."""
return cls(
client_id=request_params.get("client_id"),
@@ -66,13 +70,13 @@ class OAuthVersionDetectionParams:
authenticated_user=request_params.get("authenticated_user"),
session_id=request_params.get("session_id"),
)
@property
def has_pkce(self) -> bool:
"""Check if PKCE parameters are present."""
return bool(self.code_challenge or self.code_verifier)
@property
def is_public_client(self) -> bool:
"""Check if this appears to be a public client (no secret)."""
return bool(self.client_id and not self.client_secret)
return bool(self.client_id and not self.client_secret)

View File

@@ -4,6 +4,7 @@ Google Workspace OAuth Scopes
This module centralizes OAuth scope definitions for Google Workspace integration.
Separated from service_decorator.py to avoid circular imports.
"""
import logging
logger = logging.getLogger(__name__)
@@ -12,79 +13,66 @@ logger = logging.getLogger(__name__)
_ENABLED_TOOLS = None
# Individual OAuth Scope Constants
USERINFO_EMAIL_SCOPE = 'https://www.googleapis.com/auth/userinfo.email'
USERINFO_PROFILE_SCOPE = 'https://www.googleapis.com/auth/userinfo.profile'
OPENID_SCOPE = 'openid'
CALENDAR_SCOPE = 'https://www.googleapis.com/auth/calendar'
CALENDAR_READONLY_SCOPE = 'https://www.googleapis.com/auth/calendar.readonly'
CALENDAR_EVENTS_SCOPE = 'https://www.googleapis.com/auth/calendar.events'
USERINFO_EMAIL_SCOPE = "https://www.googleapis.com/auth/userinfo.email"
USERINFO_PROFILE_SCOPE = "https://www.googleapis.com/auth/userinfo.profile"
OPENID_SCOPE = "openid"
CALENDAR_SCOPE = "https://www.googleapis.com/auth/calendar"
CALENDAR_READONLY_SCOPE = "https://www.googleapis.com/auth/calendar.readonly"
CALENDAR_EVENTS_SCOPE = "https://www.googleapis.com/auth/calendar.events"
# Google Drive scopes
DRIVE_SCOPE = 'https://www.googleapis.com/auth/drive'
DRIVE_READONLY_SCOPE = 'https://www.googleapis.com/auth/drive.readonly'
DRIVE_FILE_SCOPE = 'https://www.googleapis.com/auth/drive.file'
DRIVE_SCOPE = "https://www.googleapis.com/auth/drive"
DRIVE_READONLY_SCOPE = "https://www.googleapis.com/auth/drive.readonly"
DRIVE_FILE_SCOPE = "https://www.googleapis.com/auth/drive.file"
# Google Docs scopes
DOCS_READONLY_SCOPE = 'https://www.googleapis.com/auth/documents.readonly'
DOCS_WRITE_SCOPE = 'https://www.googleapis.com/auth/documents'
DOCS_READONLY_SCOPE = "https://www.googleapis.com/auth/documents.readonly"
DOCS_WRITE_SCOPE = "https://www.googleapis.com/auth/documents"
# Gmail API scopes
GMAIL_READONLY_SCOPE = 'https://www.googleapis.com/auth/gmail.readonly'
GMAIL_SEND_SCOPE = 'https://www.googleapis.com/auth/gmail.send'
GMAIL_COMPOSE_SCOPE = 'https://www.googleapis.com/auth/gmail.compose'
GMAIL_MODIFY_SCOPE = 'https://www.googleapis.com/auth/gmail.modify'
GMAIL_LABELS_SCOPE = 'https://www.googleapis.com/auth/gmail.labels'
GMAIL_SETTINGS_BASIC_SCOPE = 'https://www.googleapis.com/auth/gmail.settings.basic'
GMAIL_READONLY_SCOPE = "https://www.googleapis.com/auth/gmail.readonly"
GMAIL_SEND_SCOPE = "https://www.googleapis.com/auth/gmail.send"
GMAIL_COMPOSE_SCOPE = "https://www.googleapis.com/auth/gmail.compose"
GMAIL_MODIFY_SCOPE = "https://www.googleapis.com/auth/gmail.modify"
GMAIL_LABELS_SCOPE = "https://www.googleapis.com/auth/gmail.labels"
GMAIL_SETTINGS_BASIC_SCOPE = "https://www.googleapis.com/auth/gmail.settings.basic"
# Google Chat API scopes
CHAT_READONLY_SCOPE = 'https://www.googleapis.com/auth/chat.messages.readonly'
CHAT_WRITE_SCOPE = 'https://www.googleapis.com/auth/chat.messages'
CHAT_SPACES_SCOPE = 'https://www.googleapis.com/auth/chat.spaces'
CHAT_READONLY_SCOPE = "https://www.googleapis.com/auth/chat.messages.readonly"
CHAT_WRITE_SCOPE = "https://www.googleapis.com/auth/chat.messages"
CHAT_SPACES_SCOPE = "https://www.googleapis.com/auth/chat.spaces"
# Google Sheets API scopes
SHEETS_READONLY_SCOPE = 'https://www.googleapis.com/auth/spreadsheets.readonly'
SHEETS_WRITE_SCOPE = 'https://www.googleapis.com/auth/spreadsheets'
SHEETS_READONLY_SCOPE = "https://www.googleapis.com/auth/spreadsheets.readonly"
SHEETS_WRITE_SCOPE = "https://www.googleapis.com/auth/spreadsheets"
# Google Forms API scopes
FORMS_BODY_SCOPE = 'https://www.googleapis.com/auth/forms.body'
FORMS_BODY_READONLY_SCOPE = 'https://www.googleapis.com/auth/forms.body.readonly'
FORMS_RESPONSES_READONLY_SCOPE = 'https://www.googleapis.com/auth/forms.responses.readonly'
FORMS_BODY_SCOPE = "https://www.googleapis.com/auth/forms.body"
FORMS_BODY_READONLY_SCOPE = "https://www.googleapis.com/auth/forms.body.readonly"
FORMS_RESPONSES_READONLY_SCOPE = (
"https://www.googleapis.com/auth/forms.responses.readonly"
)
# Google Slides API scopes
SLIDES_SCOPE = 'https://www.googleapis.com/auth/presentations'
SLIDES_READONLY_SCOPE = 'https://www.googleapis.com/auth/presentations.readonly'
SLIDES_SCOPE = "https://www.googleapis.com/auth/presentations"
SLIDES_READONLY_SCOPE = "https://www.googleapis.com/auth/presentations.readonly"
# Google Tasks API scopes
TASKS_SCOPE = 'https://www.googleapis.com/auth/tasks'
TASKS_READONLY_SCOPE = 'https://www.googleapis.com/auth/tasks.readonly'
TASKS_SCOPE = "https://www.googleapis.com/auth/tasks"
TASKS_READONLY_SCOPE = "https://www.googleapis.com/auth/tasks.readonly"
# Google Custom Search API scope
CUSTOM_SEARCH_SCOPE = 'https://www.googleapis.com/auth/cse'
CUSTOM_SEARCH_SCOPE = "https://www.googleapis.com/auth/cse"
# Base OAuth scopes required for user identification
BASE_SCOPES = [
USERINFO_EMAIL_SCOPE,
USERINFO_PROFILE_SCOPE,
OPENID_SCOPE
]
BASE_SCOPES = [USERINFO_EMAIL_SCOPE, USERINFO_PROFILE_SCOPE, OPENID_SCOPE]
# Service-specific scope groups
DOCS_SCOPES = [
DOCS_READONLY_SCOPE,
DOCS_WRITE_SCOPE
]
DOCS_SCOPES = [DOCS_READONLY_SCOPE, DOCS_WRITE_SCOPE]
CALENDAR_SCOPES = [
CALENDAR_SCOPE,
CALENDAR_READONLY_SCOPE,
CALENDAR_EVENTS_SCOPE
]
CALENDAR_SCOPES = [CALENDAR_SCOPE, CALENDAR_READONLY_SCOPE, CALENDAR_EVENTS_SCOPE]
DRIVE_SCOPES = [
DRIVE_SCOPE,
DRIVE_READONLY_SCOPE,
DRIVE_FILE_SCOPE
]
DRIVE_SCOPES = [DRIVE_SCOPE, DRIVE_READONLY_SCOPE, DRIVE_FILE_SCOPE]
GMAIL_SCOPES = [
GMAIL_READONLY_SCOPE,
@@ -92,58 +80,44 @@ GMAIL_SCOPES = [
GMAIL_COMPOSE_SCOPE,
GMAIL_MODIFY_SCOPE,
GMAIL_LABELS_SCOPE,
GMAIL_SETTINGS_BASIC_SCOPE
GMAIL_SETTINGS_BASIC_SCOPE,
]
CHAT_SCOPES = [
CHAT_READONLY_SCOPE,
CHAT_WRITE_SCOPE,
CHAT_SPACES_SCOPE
]
CHAT_SCOPES = [CHAT_READONLY_SCOPE, CHAT_WRITE_SCOPE, CHAT_SPACES_SCOPE]
SHEETS_SCOPES = [
SHEETS_READONLY_SCOPE,
SHEETS_WRITE_SCOPE
]
SHEETS_SCOPES = [SHEETS_READONLY_SCOPE, SHEETS_WRITE_SCOPE]
FORMS_SCOPES = [
FORMS_BODY_SCOPE,
FORMS_BODY_READONLY_SCOPE,
FORMS_RESPONSES_READONLY_SCOPE
FORMS_RESPONSES_READONLY_SCOPE,
]
SLIDES_SCOPES = [
SLIDES_SCOPE,
SLIDES_READONLY_SCOPE
]
SLIDES_SCOPES = [SLIDES_SCOPE, SLIDES_READONLY_SCOPE]
TASKS_SCOPES = [
TASKS_SCOPE,
TASKS_READONLY_SCOPE
]
TASKS_SCOPES = [TASKS_SCOPE, TASKS_READONLY_SCOPE]
CUSTOM_SEARCH_SCOPES = [
CUSTOM_SEARCH_SCOPE
]
CUSTOM_SEARCH_SCOPES = [CUSTOM_SEARCH_SCOPE]
# Tool-to-scopes mapping
TOOL_SCOPES_MAP = {
'gmail': GMAIL_SCOPES,
'drive': DRIVE_SCOPES,
'calendar': CALENDAR_SCOPES,
'docs': DOCS_SCOPES,
'sheets': SHEETS_SCOPES,
'chat': CHAT_SCOPES,
'forms': FORMS_SCOPES,
'slides': SLIDES_SCOPES,
'tasks': TASKS_SCOPES,
'search': CUSTOM_SEARCH_SCOPES
"gmail": GMAIL_SCOPES,
"drive": DRIVE_SCOPES,
"calendar": CALENDAR_SCOPES,
"docs": DOCS_SCOPES,
"sheets": SHEETS_SCOPES,
"chat": CHAT_SCOPES,
"forms": FORMS_SCOPES,
"slides": SLIDES_SCOPES,
"tasks": TASKS_SCOPES,
"search": CUSTOM_SEARCH_SCOPES,
}
def set_enabled_tools(enabled_tools):
"""
Set the globally enabled tools list.
Args:
enabled_tools: List of enabled tool names.
"""
@@ -151,11 +125,12 @@ def set_enabled_tools(enabled_tools):
_ENABLED_TOOLS = enabled_tools
logger.info(f"Enabled tools set for scope management: {enabled_tools}")
def get_current_scopes():
"""
Returns scopes for currently enabled tools.
Uses globally set enabled tools or all tools if not set.
Returns:
List of unique scopes for the enabled tools plus base scopes.
"""
@@ -163,43 +138,47 @@ def get_current_scopes():
if enabled_tools is None:
# Default behavior - return all scopes
enabled_tools = TOOL_SCOPES_MAP.keys()
# Start with base scopes (always required)
scopes = BASE_SCOPES.copy()
# Add scopes for each enabled tool
for tool in enabled_tools:
if tool in TOOL_SCOPES_MAP:
scopes.extend(TOOL_SCOPES_MAP[tool])
logger.debug(f"Generated scopes for tools {list(enabled_tools)}: {len(set(scopes))} unique scopes")
logger.debug(
f"Generated scopes for tools {list(enabled_tools)}: {len(set(scopes))} unique scopes"
)
# Return unique scopes
return list(set(scopes))
def get_scopes_for_tools(enabled_tools=None):
"""
Returns scopes for enabled tools only.
Args:
enabled_tools: List of enabled tool names. If None, returns all scopes.
Returns:
List of unique scopes for the enabled tools plus base scopes.
"""
if enabled_tools is None:
# Default behavior - return all scopes
enabled_tools = TOOL_SCOPES_MAP.keys()
# Start with base scopes (always required)
scopes = BASE_SCOPES.copy()
# Add scopes for each enabled tool
for tool in enabled_tools:
if tool in TOOL_SCOPES_MAP:
scopes.extend(TOOL_SCOPES_MAP[tool])
# Return unique scopes
return list(set(scopes))
# Combined scopes for all supported Google Workspace operations (backwards compatibility)
SCOPES = get_scopes_for_tools()
SCOPES = get_scopes_for_tools()

View File

@@ -46,6 +46,7 @@ from auth.scopes import (
logger = logging.getLogger(__name__)
# Authentication helper functions
def _get_auth_context(
tool_name: str,
@@ -136,7 +137,9 @@ def _override_oauth21_user_email(
Returns:
Tuple of (updated_user_email, updated_args)
"""
if not (use_oauth21 and authenticated_user and current_user_email != authenticated_user):
if not (
use_oauth21 and authenticated_user and current_user_email != authenticated_user
):
return current_user_email, args
service_suffix = f" for service '{service_type}'" if service_type else ""
@@ -235,7 +238,9 @@ async def get_authenticated_google_service_oauth21(
f"Authenticated account {token_email} does not match requested user {user_google_email}."
)
credentials = ensure_session_from_access_token(access_token, resolved_email, session_id)
credentials = ensure_session_from_access_token(
access_token, resolved_email, session_id
)
if not credentials:
raise GoogleAuthenticationError(
"Unable to build Google credentials from authenticated access token."
@@ -286,7 +291,9 @@ async def get_authenticated_google_service_oauth21(
return service, user_google_email
def _extract_oauth21_user_email(authenticated_user: Optional[str], func_name: str) -> str:
def _extract_oauth21_user_email(
authenticated_user: Optional[str], func_name: str
) -> str:
"""
Extract user email for OAuth 2.1 mode.
@@ -308,9 +315,7 @@ def _extract_oauth21_user_email(authenticated_user: Optional[str], func_name: st
def _extract_oauth20_user_email(
args: tuple,
kwargs: dict,
wrapper_sig: inspect.Signature
args: tuple, kwargs: dict, wrapper_sig: inspect.Signature
) -> str:
"""
Extract user email for OAuth 2.0 mode from function arguments.
@@ -331,13 +336,10 @@ def _extract_oauth20_user_email(
user_google_email = bound_args.arguments.get("user_google_email")
if not user_google_email:
raise Exception(
"'user_google_email' parameter is required but was not found."
)
raise Exception("'user_google_email' parameter is required but was not found.")
return user_google_email
def _remove_user_email_arg_from_docstring(docstring: str) -> str:
"""
Remove user_google_email parameter documentation from docstring.
@@ -357,19 +359,20 @@ def _remove_user_email_arg_from_docstring(docstring: str) -> str:
# - user_google_email: Description
# - user_google_email (str) - Description
patterns = [
r'^\s*user_google_email\s*\([^)]*\)\s*:\s*[^\n]*\.?\s*(?:Required\.?)?\s*\n',
r'^\s*user_google_email\s*:\s*[^\n]*\n',
r'^\s*user_google_email\s*\([^)]*\)\s*-\s*[^\n]*\n',
r"^\s*user_google_email\s*\([^)]*\)\s*:\s*[^\n]*\.?\s*(?:Required\.?)?\s*\n",
r"^\s*user_google_email\s*:\s*[^\n]*\n",
r"^\s*user_google_email\s*\([^)]*\)\s*-\s*[^\n]*\n",
]
modified_docstring = docstring
for pattern in patterns:
modified_docstring = re.sub(pattern, '', modified_docstring, flags=re.MULTILINE)
modified_docstring = re.sub(pattern, "", modified_docstring, flags=re.MULTILINE)
# Clean up any sequence of 3 or more newlines that might have been created
modified_docstring = re.sub(r'\n{3,}', '\n\n', modified_docstring)
modified_docstring = re.sub(r"\n{3,}", "\n\n", modified_docstring)
return modified_docstring
# Service configuration mapping
SERVICE_CONFIGS = {
"gmail": {"service": "gmail", "version": "v1"},
@@ -425,7 +428,6 @@ SCOPE_GROUPS = {
}
def _resolve_scopes(scopes: Union[str, List[str]]) -> List[str]:
"""Resolve scope names to actual scope URLs."""
if isinstance(scopes, str):
@@ -467,7 +469,6 @@ def _handle_token_refresh_error(
f"Token expired or revoked for user {user_email} accessing {service_name}"
)
service_display_name = f"Google {service_name.title()}"
return (
@@ -527,10 +528,7 @@ def require_google_service(
# In OAuth 2.1 mode, also exclude 'user_google_email' since it's automatically determined.
if is_oauth21_enabled():
# Remove both 'service' and 'user_google_email' parameters
filtered_params = [
p for p in params[1:]
if p.name != 'user_google_email'
]
filtered_params = [p for p in params[1:] if p.name != "user_google_email"]
wrapper_sig = original_sig.replace(parameters=filtered_params)
else:
# Only remove 'service' parameter for OAuth 2.0 mode
@@ -548,9 +546,13 @@ def require_google_service(
# Extract user_google_email based on OAuth mode
if is_oauth21_enabled():
user_google_email = _extract_oauth21_user_email(authenticated_user, func.__name__)
user_google_email = _extract_oauth21_user_email(
authenticated_user, func.__name__
)
else:
user_google_email = _extract_oauth20_user_email(args, kwargs, wrapper_sig)
user_google_email = _extract_oauth20_user_email(
args, kwargs, wrapper_sig
)
# Get service configuration from the decorator's arguments
if service_type not in SERVICE_CONFIGS:
@@ -628,7 +630,9 @@ def require_google_service(
# Conditionally modify docstring to remove user_google_email parameter documentation
if is_oauth21_enabled():
logger.debug('OAuth 2.1 mode enabled, removing user_google_email from docstring')
logger.debug(
"OAuth 2.1 mode enabled, removing user_google_email from docstring"
)
if func.__doc__:
wrapper.__doc__ = _remove_user_email_arg_from_docstring(func.__doc__)
@@ -664,14 +668,10 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
params = list(original_sig.parameters.values())
# Remove injected service params from the wrapper signature; drop user_google_email only for OAuth 2.1.
filtered_params = [
p for p in params
if p.name not in service_param_names
]
filtered_params = [p for p in params if p.name not in service_param_names]
if is_oauth21_enabled():
filtered_params = [
p for p in filtered_params
if p.name != 'user_google_email'
p for p in filtered_params if p.name != "user_google_email"
]
wrapper_sig = original_sig.replace(parameters=filtered_params)
@@ -685,9 +685,13 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
# Extract user_google_email based on OAuth mode
if is_oauth21_enabled():
user_google_email = _extract_oauth21_user_email(authenticated_user, tool_name)
user_google_email = _extract_oauth21_user_email(
authenticated_user, tool_name
)
else:
user_google_email = _extract_oauth20_user_email(args, kwargs, wrapper_sig)
user_google_email = _extract_oauth20_user_email(
args, kwargs, wrapper_sig
)
# Authenticate all services
for config in service_configs:
@@ -764,12 +768,12 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
# Conditionally modify docstring to remove user_google_email parameter documentation
if is_oauth21_enabled():
logger.debug('OAuth 2.1 mode enabled, removing user_google_email from docstring')
logger.debug(
"OAuth 2.1 mode enabled, removing user_google_email from docstring"
)
if func.__doc__:
wrapper.__doc__ = _remove_user_email_arg_from_docstring(func.__doc__)
return wrapper
return decorator