apply ruff formatting
This commit is contained in:
@@ -1 +1 @@
|
||||
# Make the auth directory a Python package
|
||||
# Make the auth directory a Python package
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Authentication middleware to populate context state with user information
|
||||
"""
|
||||
|
||||
import jwt
|
||||
import logging
|
||||
import os
|
||||
@@ -20,11 +21,11 @@ class AuthInfoMiddleware(Middleware):
|
||||
Middleware to extract authentication information from JWT tokens
|
||||
and populate the FastMCP context state for use in tools and prompts.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.auth_provider_type = "GoogleProvider"
|
||||
|
||||
|
||||
async def _process_request_for_auth(self, context: MiddlewareContext):
|
||||
"""Helper to extract, verify, and store auth info from a request."""
|
||||
if not context.fastmcp_context:
|
||||
@@ -42,69 +43,101 @@ class AuthInfoMiddleware(Middleware):
|
||||
headers = get_http_headers()
|
||||
if headers:
|
||||
logger.debug("Processing HTTP headers for authentication")
|
||||
|
||||
|
||||
# Get the Authorization header
|
||||
auth_header = headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token_str = auth_header[7:] # Remove "Bearer " prefix
|
||||
logger.debug("Found Bearer token")
|
||||
|
||||
|
||||
# For Google OAuth tokens (ya29.*), we need to verify them differently
|
||||
if token_str.startswith("ya29."):
|
||||
logger.debug("Detected Google OAuth access token format")
|
||||
|
||||
|
||||
# Verify the token to get user info
|
||||
from core.server import get_auth_provider
|
||||
|
||||
auth_provider = get_auth_provider()
|
||||
|
||||
|
||||
if auth_provider:
|
||||
try:
|
||||
# Verify the token
|
||||
verified_auth = await auth_provider.verify_token(token_str)
|
||||
verified_auth = await auth_provider.verify_token(
|
||||
token_str
|
||||
)
|
||||
if verified_auth:
|
||||
# Extract user info from verified token
|
||||
user_email = None
|
||||
if hasattr(verified_auth, 'claims'):
|
||||
if hasattr(verified_auth, "claims"):
|
||||
user_email = verified_auth.claims.get("email")
|
||||
|
||||
|
||||
# Get expires_at, defaulting to 1 hour from now if not available
|
||||
if hasattr(verified_auth, 'expires_at'):
|
||||
if hasattr(verified_auth, "expires_at"):
|
||||
expires_at = verified_auth.expires_at
|
||||
else:
|
||||
expires_at = int(time.time()) + 3600 # Default to 1 hour
|
||||
|
||||
expires_at = (
|
||||
int(time.time()) + 3600
|
||||
) # Default to 1 hour
|
||||
|
||||
# Get client_id from verified auth or use default
|
||||
client_id = getattr(verified_auth, 'client_id', None) or "google"
|
||||
|
||||
client_id = (
|
||||
getattr(verified_auth, "client_id", None)
|
||||
or "google"
|
||||
)
|
||||
|
||||
access_token = SimpleNamespace(
|
||||
token=token_str,
|
||||
client_id=client_id,
|
||||
scopes=verified_auth.scopes if hasattr(verified_auth, 'scopes') else [],
|
||||
scopes=verified_auth.scopes
|
||||
if hasattr(verified_auth, "scopes")
|
||||
else [],
|
||||
session_id=f"google_oauth_{token_str[:8]}",
|
||||
expires_at=expires_at,
|
||||
# Add other fields that might be needed
|
||||
sub=verified_auth.sub if hasattr(verified_auth, 'sub') else user_email,
|
||||
email=user_email
|
||||
sub=verified_auth.sub
|
||||
if hasattr(verified_auth, "sub")
|
||||
else user_email,
|
||||
email=user_email,
|
||||
)
|
||||
|
||||
|
||||
# Store in context state - this is the authoritative authentication state
|
||||
context.fastmcp_context.set_state("access_token", access_token)
|
||||
mcp_session_id = getattr(context.fastmcp_context, "session_id", None)
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token", access_token
|
||||
)
|
||||
mcp_session_id = getattr(
|
||||
context.fastmcp_context, "session_id", None
|
||||
)
|
||||
ensure_session_from_access_token(
|
||||
verified_auth,
|
||||
user_email,
|
||||
mcp_session_id,
|
||||
)
|
||||
context.fastmcp_context.set_state("access_token_obj", verified_auth)
|
||||
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
|
||||
context.fastmcp_context.set_state("token_type", "google_oauth")
|
||||
context.fastmcp_context.set_state("user_email", user_email)
|
||||
context.fastmcp_context.set_state("username", user_email)
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token_obj", verified_auth
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", self.auth_provider_type
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"token_type", "google_oauth"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"username", user_email
|
||||
)
|
||||
# Set the definitive authentication state
|
||||
context.fastmcp_context.set_state("authenticated_user_email", user_email)
|
||||
context.fastmcp_context.set_state("authenticated_via", "bearer_token")
|
||||
|
||||
logger.info(f"Authenticated via Google OAuth: {user_email}")
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "bearer_token"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Authenticated via Google OAuth: {user_email}"
|
||||
)
|
||||
else:
|
||||
logger.error("Failed to verify Google OAuth token")
|
||||
# Don't set authenticated_user_email if verification failed
|
||||
@@ -113,18 +146,29 @@ class AuthInfoMiddleware(Middleware):
|
||||
# Still store the unverified token - service decorator will handle verification
|
||||
access_token = SimpleNamespace(
|
||||
token=token_str,
|
||||
client_id=os.getenv("GOOGLE_OAUTH_CLIENT_ID", "google"),
|
||||
client_id=os.getenv(
|
||||
"GOOGLE_OAUTH_CLIENT_ID", "google"
|
||||
),
|
||||
scopes=[],
|
||||
session_id=f"google_oauth_{token_str[:8]}",
|
||||
expires_at=int(time.time()) + 3600, # Default to 1 hour
|
||||
expires_at=int(time.time())
|
||||
+ 3600, # Default to 1 hour
|
||||
sub="unknown",
|
||||
email=""
|
||||
email="",
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token", access_token
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", self.auth_provider_type
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"token_type", "google_oauth"
|
||||
)
|
||||
context.fastmcp_context.set_state("access_token", access_token)
|
||||
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
|
||||
context.fastmcp_context.set_state("token_type", "google_oauth")
|
||||
else:
|
||||
logger.warning("No auth provider available to verify Google token")
|
||||
logger.warning(
|
||||
"No auth provider available to verify Google token"
|
||||
)
|
||||
# Store unverified token
|
||||
access_token = SimpleNamespace(
|
||||
token=token_str,
|
||||
@@ -133,51 +177,93 @@ class AuthInfoMiddleware(Middleware):
|
||||
session_id=f"google_oauth_{token_str[:8]}",
|
||||
expires_at=int(time.time()) + 3600, # Default to 1 hour
|
||||
sub="unknown",
|
||||
email=""
|
||||
email="",
|
||||
)
|
||||
context.fastmcp_context.set_state("access_token", access_token)
|
||||
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
|
||||
context.fastmcp_context.set_state("token_type", "google_oauth")
|
||||
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token", access_token
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", self.auth_provider_type
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"token_type", "google_oauth"
|
||||
)
|
||||
|
||||
else:
|
||||
# Decode JWT to get user info
|
||||
try:
|
||||
token_payload = jwt.decode(
|
||||
token_str,
|
||||
options={"verify_signature": False}
|
||||
token_str, options={"verify_signature": False}
|
||||
)
|
||||
logger.debug(f"JWT payload decoded: {list(token_payload.keys())}")
|
||||
|
||||
logger.debug(
|
||||
f"JWT payload decoded: {list(token_payload.keys())}"
|
||||
)
|
||||
|
||||
# Create an AccessToken-like object
|
||||
access_token = SimpleNamespace(
|
||||
token=token_str,
|
||||
client_id=token_payload.get("client_id", "unknown"),
|
||||
scopes=token_payload.get("scope", "").split() if token_payload.get("scope") else [],
|
||||
session_id=token_payload.get("sid", token_payload.get("jti", token_payload.get("session_id", "unknown"))),
|
||||
expires_at=token_payload.get("exp", 0)
|
||||
scopes=token_payload.get("scope", "").split()
|
||||
if token_payload.get("scope")
|
||||
else [],
|
||||
session_id=token_payload.get(
|
||||
"sid",
|
||||
token_payload.get(
|
||||
"jti",
|
||||
token_payload.get("session_id", "unknown"),
|
||||
),
|
||||
),
|
||||
expires_at=token_payload.get("exp", 0),
|
||||
)
|
||||
|
||||
|
||||
# Store in context state
|
||||
context.fastmcp_context.set_state("access_token", access_token)
|
||||
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token", access_token
|
||||
)
|
||||
|
||||
# Store additional user info
|
||||
context.fastmcp_context.set_state("user_id", token_payload.get("sub"))
|
||||
context.fastmcp_context.set_state("username", token_payload.get("username", token_payload.get("email")))
|
||||
context.fastmcp_context.set_state("name", token_payload.get("name"))
|
||||
context.fastmcp_context.set_state("auth_time", token_payload.get("auth_time"))
|
||||
context.fastmcp_context.set_state("issuer", token_payload.get("iss"))
|
||||
context.fastmcp_context.set_state("audience", token_payload.get("aud"))
|
||||
context.fastmcp_context.set_state("jti", token_payload.get("jti"))
|
||||
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
|
||||
|
||||
context.fastmcp_context.set_state(
|
||||
"user_id", token_payload.get("sub")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"username",
|
||||
token_payload.get(
|
||||
"username", token_payload.get("email")
|
||||
),
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"name", token_payload.get("name")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_time", token_payload.get("auth_time")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"issuer", token_payload.get("iss")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"audience", token_payload.get("aud")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"jti", token_payload.get("jti")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", self.auth_provider_type
|
||||
)
|
||||
|
||||
# Set the definitive authentication state for JWT tokens
|
||||
user_email = token_payload.get("email", token_payload.get("username"))
|
||||
user_email = token_payload.get(
|
||||
"email", token_payload.get("username")
|
||||
)
|
||||
if user_email:
|
||||
context.fastmcp_context.set_state("authenticated_user_email", user_email)
|
||||
context.fastmcp_context.set_state("authenticated_via", "jwt_token")
|
||||
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "jwt_token"
|
||||
)
|
||||
|
||||
logger.debug("JWT token processed successfully")
|
||||
|
||||
|
||||
except jwt.DecodeError as e:
|
||||
logger.error(f"Failed to decode JWT: {e}")
|
||||
except Exception as e:
|
||||
@@ -185,44 +271,58 @@ class AuthInfoMiddleware(Middleware):
|
||||
else:
|
||||
logger.debug("No Bearer token in Authorization header")
|
||||
else:
|
||||
logger.debug("No HTTP headers available (might be using stdio transport)")
|
||||
logger.debug(
|
||||
"No HTTP headers available (might be using stdio transport)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get HTTP request: {e}")
|
||||
|
||||
|
||||
# After trying HTTP headers, check for other authentication methods
|
||||
# This consolidates all authentication logic in the middleware
|
||||
if not context.fastmcp_context.get_state("authenticated_user_email"):
|
||||
logger.debug("No authentication found via bearer token, checking other methods")
|
||||
|
||||
logger.debug(
|
||||
"No authentication found via bearer token, checking other methods"
|
||||
)
|
||||
|
||||
# Check transport mode
|
||||
from core.config import get_transport_mode
|
||||
|
||||
transport_mode = get_transport_mode()
|
||||
|
||||
|
||||
if transport_mode == "stdio":
|
||||
# In stdio mode, check if there's a session with credentials
|
||||
# This is ONLY safe in stdio mode because it's single-user
|
||||
logger.debug("Checking for stdio mode authentication")
|
||||
|
||||
|
||||
# Get the requested user from the context if available
|
||||
requested_user = None
|
||||
if hasattr(context, 'request') and hasattr(context.request, 'params'):
|
||||
requested_user = context.request.params.get('user_google_email')
|
||||
elif hasattr(context, 'arguments'):
|
||||
if hasattr(context, "request") and hasattr(context.request, "params"):
|
||||
requested_user = context.request.params.get("user_google_email")
|
||||
elif hasattr(context, "arguments"):
|
||||
# FastMCP may store arguments differently
|
||||
requested_user = context.arguments.get('user_google_email')
|
||||
|
||||
requested_user = context.arguments.get("user_google_email")
|
||||
|
||||
if requested_user:
|
||||
try:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
|
||||
# Check if user has a recent session
|
||||
if store.has_session(requested_user):
|
||||
logger.debug(f"Using recent stdio session for {requested_user}")
|
||||
logger.debug(
|
||||
f"Using recent stdio session for {requested_user}"
|
||||
)
|
||||
# In stdio mode, we can trust the user has authenticated recently
|
||||
context.fastmcp_context.set_state("authenticated_user_email", requested_user)
|
||||
context.fastmcp_context.set_state("authenticated_via", "stdio_session")
|
||||
context.fastmcp_context.set_state("auth_provider_type", "oauth21_stdio")
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", requested_user
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "stdio_session"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", "oauth21_stdio"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking stdio session: {e}")
|
||||
|
||||
@@ -230,73 +330,95 @@ class AuthInfoMiddleware(Middleware):
|
||||
if not context.fastmcp_context.get_state("authenticated_user_email"):
|
||||
try:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
|
||||
store = get_oauth21_session_store()
|
||||
single_user = store.get_single_user_email()
|
||||
if single_user:
|
||||
logger.debug(
|
||||
f"Defaulting to single stdio OAuth session for {single_user}"
|
||||
)
|
||||
context.fastmcp_context.set_state("authenticated_user_email", single_user)
|
||||
context.fastmcp_context.set_state("authenticated_via", "stdio_single_session")
|
||||
context.fastmcp_context.set_state("auth_provider_type", "oauth21_stdio")
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", single_user
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "stdio_single_session"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", "oauth21_stdio"
|
||||
)
|
||||
context.fastmcp_context.set_state("user_email", single_user)
|
||||
context.fastmcp_context.set_state("username", single_user)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error determining stdio single-user session: {e}")
|
||||
|
||||
logger.debug(
|
||||
f"Error determining stdio single-user session: {e}"
|
||||
)
|
||||
|
||||
# Check for MCP session binding
|
||||
if not context.fastmcp_context.get_state("authenticated_user_email") and hasattr(context.fastmcp_context, 'session_id'):
|
||||
if not context.fastmcp_context.get_state(
|
||||
"authenticated_user_email"
|
||||
) and hasattr(context.fastmcp_context, "session_id"):
|
||||
mcp_session_id = context.fastmcp_context.session_id
|
||||
if mcp_session_id:
|
||||
try:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
|
||||
# Check if this MCP session is bound to a user
|
||||
bound_user = store.get_user_by_mcp_session(mcp_session_id)
|
||||
if bound_user:
|
||||
logger.debug(f"MCP session bound to {bound_user}")
|
||||
context.fastmcp_context.set_state("authenticated_user_email", bound_user)
|
||||
context.fastmcp_context.set_state("authenticated_via", "mcp_session_binding")
|
||||
context.fastmcp_context.set_state("auth_provider_type", "oauth21_session")
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", bound_user
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "mcp_session_binding"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", "oauth21_session"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking MCP session binding: {e}")
|
||||
|
||||
|
||||
async def on_call_tool(self, context: MiddlewareContext, call_next):
|
||||
"""Extract auth info from token and set in context state"""
|
||||
logger.debug("Processing tool call authentication")
|
||||
|
||||
|
||||
try:
|
||||
await self._process_request_for_auth(context)
|
||||
|
||||
|
||||
logger.debug("Passing to next handler")
|
||||
result = await call_next(context)
|
||||
logger.debug("Handler completed")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Check if this is an authentication error - don't log traceback for these
|
||||
if "GoogleAuthenticationError" in str(type(e)) or "Access denied: Cannot retrieve credentials" in str(e):
|
||||
if "GoogleAuthenticationError" in str(
|
||||
type(e)
|
||||
) or "Access denied: Cannot retrieve credentials" in str(e):
|
||||
logger.info(f"Authentication check failed: {e}")
|
||||
else:
|
||||
logger.error(f"Error in on_call_tool middleware: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def on_get_prompt(self, context: MiddlewareContext, call_next):
|
||||
"""Extract auth info for prompt requests too"""
|
||||
logger.debug("Processing prompt authentication")
|
||||
|
||||
|
||||
try:
|
||||
await self._process_request_for_auth(context)
|
||||
|
||||
|
||||
logger.debug("Passing prompt to next handler")
|
||||
result = await call_next(context)
|
||||
logger.debug("Prompt handler completed")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Check if this is an authentication error - don't log traceback for these
|
||||
if "GoogleAuthenticationError" in str(type(e)) or "Access denied: Cannot retrieve credentials" in str(e):
|
||||
if "GoogleAuthenticationError" in str(
|
||||
type(e)
|
||||
) or "Access denied: Cannot retrieve credentials" in str(e):
|
||||
logger.info(f"Authentication check failed in prompt: {e}")
|
||||
else:
|
||||
logger.error(f"Error in on_get_prompt middleware: {e}", exc_info=True)
|
||||
|
||||
@@ -4,6 +4,7 @@ External OAuth Provider for Google Workspace MCP
|
||||
Extends FastMCP's GoogleProvider to support external OAuth flows where
|
||||
access tokens (ya29.*) are issued by external systems and need validation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
@@ -55,7 +56,7 @@ class ExternalOAuthProvider(GoogleProvider):
|
||||
token=token,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=self._client_id,
|
||||
client_secret=self._client_secret
|
||||
client_secret=self._client_secret,
|
||||
)
|
||||
|
||||
# Validate token by calling userinfo API
|
||||
@@ -63,20 +64,27 @@ class ExternalOAuthProvider(GoogleProvider):
|
||||
|
||||
if user_info and user_info.get("email"):
|
||||
# Token is valid - create AccessToken object
|
||||
logger.info(f"Validated external access token for: {user_info['email']}")
|
||||
logger.info(
|
||||
f"Validated external access token for: {user_info['email']}"
|
||||
)
|
||||
|
||||
# Create a mock AccessToken that the middleware expects
|
||||
# This matches the structure that FastMCP's AccessToken would have
|
||||
from types import SimpleNamespace
|
||||
|
||||
scope_list = list(getattr(self, "required_scopes", []) or [])
|
||||
access_token = SimpleNamespace(
|
||||
token=token,
|
||||
scopes=scope_list,
|
||||
expires_at=int(time.time()) + 3600, # Default to 1-hour validity
|
||||
claims={"email": user_info["email"], "sub": user_info.get("id")},
|
||||
expires_at=int(time.time())
|
||||
+ 3600, # Default to 1-hour validity
|
||||
claims={
|
||||
"email": user_info["email"],
|
||||
"sub": user_info.get("id"),
|
||||
},
|
||||
client_id=self._client_id,
|
||||
email=user_info["email"],
|
||||
sub=user_info.get("id")
|
||||
sub=user_info.get("id"),
|
||||
)
|
||||
return access_token
|
||||
else:
|
||||
|
||||
@@ -15,7 +15,7 @@ from google.auth.transport.requests import Request
|
||||
from google.auth.exceptions import RefreshError
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
from auth.scopes import SCOPES, get_current_scopes # noqa
|
||||
from auth.scopes import SCOPES, get_current_scopes # noqa
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
from auth.credential_store import get_credential_store
|
||||
from auth.oauth_config import get_oauth_config, is_stateless_mode
|
||||
@@ -136,11 +136,15 @@ def save_credentials_to_session(session_id: str, credentials: Credentials):
|
||||
client_secret=credentials.client_secret,
|
||||
scopes=credentials.scopes,
|
||||
expiry=credentials.expiry,
|
||||
mcp_session_id=session_id
|
||||
mcp_session_id=session_id,
|
||||
)
|
||||
logger.debug(
|
||||
f"Credentials saved to OAuth21SessionStore for session_id: {session_id}, user: {user_email}"
|
||||
)
|
||||
logger.debug(f"Credentials saved to OAuth21SessionStore for session_id: {session_id}, user: {user_email}")
|
||||
else:
|
||||
logger.warning(f"Could not save credentials to session store - no user email found for session: {session_id}")
|
||||
logger.warning(
|
||||
f"Could not save credentials to session store - no user email found for session: {session_id}"
|
||||
)
|
||||
|
||||
|
||||
def load_credentials_from_session(session_id: str) -> Optional[Credentials]:
|
||||
@@ -359,7 +363,9 @@ async def start_auth_flow(
|
||||
try:
|
||||
session_id = get_fastmcp_session_id()
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not retrieve FastMCP session ID for state binding: {e}")
|
||||
logger.debug(
|
||||
f"Could not retrieve FastMCP session ID for state binding: {e}"
|
||||
)
|
||||
|
||||
store = get_oauth21_session_store()
|
||||
store.store_oauth_state(oauth_state, session_id=session_id)
|
||||
@@ -460,16 +466,16 @@ def handle_auth_callback(
|
||||
state_values = parse_qs(parsed_response.query).get("state")
|
||||
state = state_values[0] if state_values else None
|
||||
|
||||
state_info = store.validate_and_consume_oauth_state(state, session_id=session_id)
|
||||
state_info = store.validate_and_consume_oauth_state(
|
||||
state, session_id=session_id
|
||||
)
|
||||
logger.debug(
|
||||
"Validated OAuth callback state %s for session %s",
|
||||
(state[:8] if state else "<missing>"),
|
||||
state_info.get("session_id") or "<unknown>",
|
||||
)
|
||||
|
||||
flow = create_oauth_flow(
|
||||
scopes=scopes, redirect_uri=redirect_uri, state=state
|
||||
)
|
||||
flow = create_oauth_flow(scopes=scopes, redirect_uri=redirect_uri, state=state)
|
||||
|
||||
# Exchange the authorization code for credentials
|
||||
# Note: fetch_token will use the redirect_uri configured in the flow
|
||||
@@ -502,7 +508,7 @@ def handle_auth_callback(
|
||||
scopes=credentials.scopes,
|
||||
expiry=credentials.expiry,
|
||||
mcp_session_id=session_id,
|
||||
issuer="https://accounts.google.com" # Add issuer for Google tokens
|
||||
issuer="https://accounts.google.com", # Add issuer for Google tokens
|
||||
)
|
||||
|
||||
# If session_id is provided, also save to session cache for compatibility
|
||||
@@ -546,7 +552,9 @@ def get_credentials(
|
||||
# Try to get credentials by MCP session
|
||||
credentials = store.get_credentials_by_mcp_session(session_id)
|
||||
if credentials:
|
||||
logger.info(f"[get_credentials] Found OAuth 2.1 credentials for MCP session {session_id}")
|
||||
logger.info(
|
||||
f"[get_credentials] Found OAuth 2.1 credentials for MCP session {session_id}"
|
||||
)
|
||||
|
||||
# Check scopes
|
||||
if not all(scope in credentials.scopes for scope in required_scopes):
|
||||
@@ -562,7 +570,9 @@ def get_credentials(
|
||||
# Try to refresh
|
||||
try:
|
||||
credentials.refresh(Request())
|
||||
logger.info(f"[get_credentials] Refreshed OAuth 2.1 credentials for session {session_id}")
|
||||
logger.info(
|
||||
f"[get_credentials] Refreshed OAuth 2.1 credentials for session {session_id}"
|
||||
)
|
||||
# Update stored credentials
|
||||
user_email = store.get_user_by_mcp_session(session_id)
|
||||
if user_email:
|
||||
@@ -572,11 +582,13 @@ def get_credentials(
|
||||
refresh_token=credentials.refresh_token,
|
||||
scopes=credentials.scopes,
|
||||
expiry=credentials.expiry,
|
||||
mcp_session_id=session_id
|
||||
mcp_session_id=session_id,
|
||||
)
|
||||
return credentials
|
||||
except Exception as e:
|
||||
logger.error(f"[get_credentials] Failed to refresh OAuth 2.1 credentials: {e}")
|
||||
logger.error(
|
||||
f"[get_credentials] Failed to refresh OAuth 2.1 credentials: {e}"
|
||||
)
|
||||
return None
|
||||
except ImportError:
|
||||
pass # OAuth 2.1 store not available
|
||||
@@ -692,7 +704,9 @@ def get_credentials(
|
||||
credential_store = get_credential_store()
|
||||
credential_store.store_credential(user_google_email, credentials)
|
||||
else:
|
||||
logger.info(f"Skipping credential file save in stateless mode for {user_google_email}")
|
||||
logger.info(
|
||||
f"Skipping credential file save in stateless mode for {user_google_email}"
|
||||
)
|
||||
|
||||
# Also update OAuth21SessionStore
|
||||
store = get_oauth21_session_store()
|
||||
@@ -706,7 +720,7 @@ def get_credentials(
|
||||
scopes=credentials.scopes,
|
||||
expiry=credentials.expiry,
|
||||
mcp_session_id=session_id,
|
||||
issuer="https://accounts.google.com" # Add issuer for Google tokens
|
||||
issuer="https://accounts.google.com", # Add issuer for Google tokens
|
||||
)
|
||||
|
||||
if session_id: # Update session cache if it was the source or is active
|
||||
@@ -795,9 +809,13 @@ async def get_authenticated_google_service(
|
||||
# First try context variable (works in async context)
|
||||
session_id = get_fastmcp_session_id()
|
||||
if session_id:
|
||||
logger.debug(f"[{tool_name}] Got FastMCP session ID from context: {session_id}")
|
||||
logger.debug(
|
||||
f"[{tool_name}] Got FastMCP session ID from context: {session_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"[{tool_name}] Context variable returned None/empty session ID")
|
||||
logger.debug(
|
||||
f"[{tool_name}] Context variable returned None/empty session ID"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"[{tool_name}] Could not get FastMCP session from context: {e}"
|
||||
@@ -807,17 +825,25 @@ async def get_authenticated_google_service(
|
||||
if not session_id and get_fastmcp_context:
|
||||
try:
|
||||
fastmcp_ctx = get_fastmcp_context()
|
||||
if fastmcp_ctx and hasattr(fastmcp_ctx, 'session_id'):
|
||||
if fastmcp_ctx and hasattr(fastmcp_ctx, "session_id"):
|
||||
session_id = fastmcp_ctx.session_id
|
||||
logger.debug(f"[{tool_name}] Got FastMCP session ID directly: {session_id}")
|
||||
logger.debug(
|
||||
f"[{tool_name}] Got FastMCP session ID directly: {session_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"[{tool_name}] FastMCP context exists but no session_id attribute")
|
||||
logger.debug(
|
||||
f"[{tool_name}] FastMCP context exists but no session_id attribute"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"[{tool_name}] Could not get FastMCP context directly: {e}")
|
||||
logger.debug(
|
||||
f"[{tool_name}] Could not get FastMCP context directly: {e}"
|
||||
)
|
||||
|
||||
# Final fallback: log if we still don't have session_id
|
||||
if not session_id:
|
||||
logger.warning(f"[{tool_name}] Unable to obtain FastMCP session ID from any source")
|
||||
logger.warning(
|
||||
f"[{tool_name}] Unable to obtain FastMCP session ID from any source"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[{tool_name}] Attempting to get authenticated {service_name} service. Email: '{user_google_email}', Session: '{session_id}'"
|
||||
@@ -838,8 +864,12 @@ async def get_authenticated_google_service(
|
||||
)
|
||||
|
||||
if not credentials or not credentials.valid:
|
||||
logger.warning(f"[{tool_name}] No valid credentials. Email: '{user_google_email}'.")
|
||||
logger.info(f"[{tool_name}] Valid email '{user_google_email}' provided, initiating auth flow.")
|
||||
logger.warning(
|
||||
f"[{tool_name}] No valid credentials. Email: '{user_google_email}'."
|
||||
)
|
||||
logger.info(
|
||||
f"[{tool_name}] Valid email '{user_google_email}' provided, initiating auth flow."
|
||||
)
|
||||
|
||||
# Ensure OAuth callback is available
|
||||
from auth.oauth_callback_server import ensure_oauth_callback_available
|
||||
|
||||
@@ -26,24 +26,26 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
|
||||
Middleware that extracts session information from requests and makes it
|
||||
available to MCP tool functions via context variables.
|
||||
"""
|
||||
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
"""Process request and set session context."""
|
||||
|
||||
logger.debug(f"MCPSessionMiddleware processing request: {request.method} {request.url.path}")
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"MCPSessionMiddleware processing request: {request.method} {request.url.path}"
|
||||
)
|
||||
|
||||
# Skip non-MCP paths
|
||||
if not request.url.path.startswith("/mcp"):
|
||||
logger.debug(f"Skipping non-MCP path: {request.url.path}")
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
session_context = None
|
||||
|
||||
|
||||
try:
|
||||
# Extract session information
|
||||
headers = dict(request.headers)
|
||||
session_id = extract_session_from_headers(headers)
|
||||
|
||||
|
||||
# Try to get OAuth 2.1 auth context from FastMCP
|
||||
auth_context = None
|
||||
user_email = None
|
||||
@@ -52,28 +54,33 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
|
||||
if hasattr(request.state, "auth"):
|
||||
auth_context = request.state.auth
|
||||
# Extract user email from auth claims if available
|
||||
if hasattr(auth_context, 'claims') and auth_context.claims:
|
||||
user_email = auth_context.claims.get('email')
|
||||
|
||||
if hasattr(auth_context, "claims") and auth_context.claims:
|
||||
user_email = auth_context.claims.get("email")
|
||||
|
||||
# Check for FastMCP session ID (from streamable HTTP transport)
|
||||
if hasattr(request.state, "session_id"):
|
||||
mcp_session_id = request.state.session_id
|
||||
logger.debug(f"Found FastMCP session ID: {mcp_session_id}")
|
||||
|
||||
|
||||
# Also check Authorization header for bearer tokens
|
||||
auth_header = headers.get("authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer ") and not user_email:
|
||||
if (
|
||||
auth_header
|
||||
and auth_header.lower().startswith("bearer ")
|
||||
and not user_email
|
||||
):
|
||||
try:
|
||||
import jwt
|
||||
|
||||
token = auth_header[7:] # Remove "Bearer " prefix
|
||||
# Decode without verification to extract email
|
||||
claims = jwt.decode(token, options={"verify_signature": False})
|
||||
user_email = claims.get('email')
|
||||
user_email = claims.get("email")
|
||||
if user_email:
|
||||
logger.debug(f"Extracted user email from JWT: {user_email}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Build session context
|
||||
if session_id or auth_context or user_email or mcp_session_id:
|
||||
# Create session ID hierarchy: explicit session_id > Google user session > FastMCP session
|
||||
@@ -82,10 +89,11 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
|
||||
effective_session_id = f"google_{user_email}"
|
||||
elif not effective_session_id and mcp_session_id:
|
||||
effective_session_id = mcp_session_id
|
||||
|
||||
|
||||
session_context = SessionContext(
|
||||
session_id=effective_session_id,
|
||||
user_id=user_email or (auth_context.user_id if auth_context else None),
|
||||
user_id=user_email
|
||||
or (auth_context.user_id if auth_context else None),
|
||||
auth_context=auth_context,
|
||||
request=request,
|
||||
metadata={
|
||||
@@ -93,19 +101,19 @@ class MCPSessionMiddleware(BaseHTTPMiddleware):
|
||||
"method": request.method,
|
||||
"user_email": user_email,
|
||||
"mcp_session_id": mcp_session_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"MCP request with session: session_id={session_context.session_id}, "
|
||||
f"user_id={session_context.user_id}, path={request.url.path}"
|
||||
)
|
||||
|
||||
|
||||
# Process request with session context
|
||||
with SessionContextManager(session_context):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in MCP session middleware: {e}")
|
||||
# Continue without session context
|
||||
|
||||
@@ -34,7 +34,9 @@ def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
|
||||
try:
|
||||
return expiry.astimezone(timezone.utc).replace(tzinfo=None)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.debug("Failed to normalize aware expiry; returning without tzinfo")
|
||||
logger.debug(
|
||||
"Failed to normalize aware expiry; returning without tzinfo"
|
||||
)
|
||||
return expiry.replace(tzinfo=None)
|
||||
return expiry # Already naive; assumed to represent UTC
|
||||
|
||||
@@ -51,15 +53,15 @@ def _normalize_expiry_to_naive_utc(expiry: Optional[Any]) -> Optional[datetime]:
|
||||
|
||||
|
||||
# Context variable to store the current session information
|
||||
_current_session_context: contextvars.ContextVar[Optional['SessionContext']] = contextvars.ContextVar(
|
||||
'current_session_context',
|
||||
default=None
|
||||
_current_session_context: contextvars.ContextVar[Optional["SessionContext"]] = (
|
||||
contextvars.ContextVar("current_session_context", default=None)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""Container for session-related information."""
|
||||
|
||||
session_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
auth_context: Optional[Any] = None
|
||||
@@ -81,7 +83,9 @@ def set_session_context(context: Optional[SessionContext]):
|
||||
"""
|
||||
_current_session_context.set(context)
|
||||
if context:
|
||||
logger.debug(f"Set session context: session_id={context.session_id}, user_id={context.user_id}")
|
||||
logger.debug(
|
||||
f"Set session context: session_id={context.session_id}, user_id={context.user_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug("Cleared session context")
|
||||
|
||||
@@ -161,6 +165,7 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
|
||||
# If no session found, create a temporary session ID from token hash
|
||||
# This allows header-based authentication to work with session context
|
||||
import hashlib
|
||||
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()[:8]
|
||||
return f"bearer_token_{token_hash}"
|
||||
|
||||
@@ -171,6 +176,7 @@ def extract_session_from_headers(headers: Dict[str, str]) -> Optional[str]:
|
||||
# OAuth21SessionStore - Main Session Management
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class OAuth21SessionStore:
|
||||
"""
|
||||
Global store for OAuth 2.1 authenticated sessions.
|
||||
@@ -185,8 +191,12 @@ class OAuth21SessionStore:
|
||||
|
||||
def __init__(self):
|
||||
self._sessions: Dict[str, Dict[str, Any]] = {}
|
||||
self._mcp_session_mapping: Dict[str, str] = {} # Maps FastMCP session ID -> user email
|
||||
self._session_auth_binding: Dict[str, str] = {} # Maps session ID -> authenticated user email (immutable)
|
||||
self._mcp_session_mapping: Dict[
|
||||
str, str
|
||||
] = {} # Maps FastMCP session ID -> user email
|
||||
self._session_auth_binding: Dict[
|
||||
str, str
|
||||
] = {} # Maps session ID -> authenticated user email (immutable)
|
||||
self._oauth_states: Dict[str, Dict[str, Any]] = {}
|
||||
self._lock = RLock()
|
||||
|
||||
@@ -258,7 +268,9 @@ class OAuth21SessionStore:
|
||||
state_info = self._oauth_states.get(state)
|
||||
|
||||
if not state_info:
|
||||
logger.error("SECURITY: OAuth callback received unknown or expired state")
|
||||
logger.error(
|
||||
"SECURITY: OAuth callback received unknown or expired state"
|
||||
)
|
||||
raise ValueError("Invalid or expired OAuth state parameter")
|
||||
|
||||
bound_session = state_info.get("session_id")
|
||||
@@ -332,16 +344,26 @@ class OAuth21SessionStore:
|
||||
# Create immutable session binding (first binding wins, cannot be changed)
|
||||
if mcp_session_id not in self._session_auth_binding:
|
||||
self._session_auth_binding[mcp_session_id] = user_email
|
||||
logger.info(f"Created immutable session binding: {mcp_session_id} -> {user_email}")
|
||||
logger.info(
|
||||
f"Created immutable session binding: {mcp_session_id} -> {user_email}"
|
||||
)
|
||||
elif self._session_auth_binding[mcp_session_id] != user_email:
|
||||
# Security: Attempt to bind session to different user
|
||||
logger.error(f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}")
|
||||
raise ValueError(f"Session {mcp_session_id} is already bound to a different user")
|
||||
logger.error(
|
||||
f"SECURITY: Attempt to rebind session {mcp_session_id} from {self._session_auth_binding[mcp_session_id]} to {user_email}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Session {mcp_session_id} is already bound to a different user"
|
||||
)
|
||||
|
||||
self._mcp_session_mapping[mcp_session_id] = user_email
|
||||
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})")
|
||||
logger.info(
|
||||
f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id}, mcp_session_id: {mcp_session_id})"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})")
|
||||
logger.info(
|
||||
f"Stored OAuth 2.1 session for {user_email} (session_id: {session_id})"
|
||||
)
|
||||
|
||||
# Also create binding for the OAuth session ID
|
||||
if session_id and session_id not in self._session_auth_binding:
|
||||
@@ -382,7 +404,9 @@ class OAuth21SessionStore:
|
||||
logger.error(f"Failed to create credentials for {user_email}: {e}")
|
||||
return None
|
||||
|
||||
def get_credentials_by_mcp_session(self, mcp_session_id: str) -> Optional[Credentials]:
|
||||
def get_credentials_by_mcp_session(
|
||||
self, mcp_session_id: str
|
||||
) -> Optional[Credentials]:
|
||||
"""
|
||||
Get Google credentials using FastMCP session ID.
|
||||
|
||||
@@ -407,7 +431,7 @@ class OAuth21SessionStore:
|
||||
requested_user_email: str,
|
||||
session_id: Optional[str] = None,
|
||||
auth_token_email: Optional[str] = None,
|
||||
allow_recent_auth: bool = False
|
||||
allow_recent_auth: bool = False,
|
||||
) -> Optional[Credentials]:
|
||||
"""
|
||||
Get Google credentials with session validation.
|
||||
@@ -466,6 +490,7 @@ class OAuth21SessionStore:
|
||||
# Check transport mode to ensure this is only used in stdio
|
||||
try:
|
||||
from core.config import get_transport_mode
|
||||
|
||||
transport_mode = get_transport_mode()
|
||||
if transport_mode != "stdio":
|
||||
logger.error(
|
||||
@@ -533,7 +558,9 @@ class OAuth21SessionStore:
|
||||
# Also remove from auth binding
|
||||
if mcp_session_id in self._session_auth_binding:
|
||||
del self._session_auth_binding[mcp_session_id]
|
||||
logger.info(f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}")
|
||||
logger.info(
|
||||
f"Removed OAuth 2.1 session for {user_email} and MCP mapping for {mcp_session_id}"
|
||||
)
|
||||
|
||||
# Remove OAuth session binding if exists
|
||||
if session_id and session_id in self._session_auth_binding:
|
||||
@@ -612,7 +639,9 @@ def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
|
||||
try:
|
||||
client_secret = secret_obj.get_secret_value() # type: ignore[call-arg]
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(f"Failed to resolve client secret from provider: {exc}")
|
||||
logger.debug(
|
||||
f"Failed to resolve client secret from provider: {exc}"
|
||||
)
|
||||
elif isinstance(secret_obj, str):
|
||||
client_secret = secret_obj
|
||||
|
||||
@@ -629,7 +658,9 @@ def _resolve_client_credentials() -> Tuple[Optional[str], Optional[str]]:
|
||||
return client_id, client_secret
|
||||
|
||||
|
||||
def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Credentials]:
|
||||
def _build_credentials_from_provider(
|
||||
access_token: AccessToken,
|
||||
) -> Optional[Credentials]:
|
||||
"""Construct Google credentials from the provider cache."""
|
||||
if not _auth_provider:
|
||||
return None
|
||||
@@ -640,10 +671,14 @@ def _build_credentials_from_provider(access_token: AccessToken) -> Optional[Cred
|
||||
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
|
||||
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(access_token.token)
|
||||
refresh_token_value = getattr(_auth_provider, "_access_to_refresh", {}).get(
|
||||
access_token.token
|
||||
)
|
||||
refresh_token_obj = None
|
||||
if refresh_token_value:
|
||||
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(refresh_token_value)
|
||||
refresh_token_obj = getattr(_auth_provider, "_refresh_tokens", {}).get(
|
||||
refresh_token_value
|
||||
)
|
||||
|
||||
expiry = None
|
||||
expires_at = getattr(access_entry, "expires_at", None)
|
||||
@@ -730,7 +765,9 @@ def ensure_session_from_access_token(
|
||||
return credentials
|
||||
|
||||
|
||||
def get_credentials_from_token(access_token: str, user_email: Optional[str] = None) -> Optional[Credentials]:
|
||||
def get_credentials_from_token(
|
||||
access_token: str, user_email: Optional[str] = None
|
||||
) -> Optional[Credentials]:
|
||||
"""
|
||||
Convert a bearer token to Google credentials.
|
||||
|
||||
@@ -753,14 +790,18 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
|
||||
# If the FastMCP provider is managing tokens, sync from provider storage
|
||||
if _auth_provider:
|
||||
access_record = getattr(_auth_provider, "_access_tokens", {}).get(access_token)
|
||||
access_record = getattr(_auth_provider, "_access_tokens", {}).get(
|
||||
access_token
|
||||
)
|
||||
if access_record:
|
||||
logger.debug("Building credentials from FastMCP provider cache")
|
||||
return ensure_session_from_access_token(access_record, user_email)
|
||||
|
||||
# Otherwise, create minimal credentials with just the access token
|
||||
# Assume token is valid for 1 hour (typical for Google tokens)
|
||||
expiry = _normalize_expiry_to_naive_utc(datetime.now(timezone.utc) + timedelta(hours=1))
|
||||
expiry = _normalize_expiry_to_naive_utc(
|
||||
datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
)
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
|
||||
credentials = Credentials(
|
||||
@@ -770,7 +811,7 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scopes=None,
|
||||
expiry=expiry
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
logger.debug("Created fallback Google credentials from bearer token")
|
||||
@@ -781,7 +822,9 @@ def get_credentials_from_token(access_token: str, user_email: Optional[str] = No
|
||||
return None
|
||||
|
||||
|
||||
def store_token_session(token_response: dict, user_email: str, mcp_session_id: Optional[str] = None) -> str:
|
||||
def store_token_session(
|
||||
token_response: dict, user_email: str, mcp_session_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Store a token response in the session store.
|
||||
|
||||
@@ -802,9 +845,12 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
if not mcp_session_id:
|
||||
try:
|
||||
from core.context import get_fastmcp_session_id
|
||||
|
||||
mcp_session_id = get_fastmcp_session_id()
|
||||
if mcp_session_id:
|
||||
logger.debug(f"Got FastMCP session ID from context: {mcp_session_id}")
|
||||
logger.debug(
|
||||
f"Got FastMCP session ID from context: {mcp_session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get FastMCP session from context: {e}")
|
||||
|
||||
@@ -815,7 +861,9 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
client_id, client_secret = _resolve_client_credentials()
|
||||
scopes = token_response.get("scope", "")
|
||||
scopes_list = scopes.split() if scopes else None
|
||||
expiry = datetime.now(timezone.utc) + timedelta(seconds=token_response.get("expires_in", 3600))
|
||||
expiry = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=token_response.get("expires_in", 3600)
|
||||
)
|
||||
|
||||
store.store_session(
|
||||
user_email=user_email,
|
||||
@@ -832,7 +880,9 @@ def store_token_session(token_response: dict, user_email: str, mcp_session_id: O
|
||||
)
|
||||
|
||||
if mcp_session_id:
|
||||
logger.info(f"Stored token session for {user_email} with MCP session {mcp_session_id}")
|
||||
logger.info(
|
||||
f"Stored token session for {user_email} with MCP session {mcp_session_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Stored token session for {user_email}")
|
||||
|
||||
|
||||
@@ -17,13 +17,18 @@ from fastapi.responses import FileResponse, JSONResponse
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from auth.scopes import SCOPES, get_current_scopes # noqa
|
||||
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
|
||||
from auth.scopes import SCOPES, get_current_scopes # noqa
|
||||
from auth.oauth_responses import (
|
||||
create_error_response,
|
||||
create_success_response,
|
||||
create_server_error_response,
|
||||
)
|
||||
from auth.google_auth import handle_auth_callback, check_client_secrets
|
||||
from auth.oauth_config import get_oauth_redirect_uri
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MinimalOAuthServer:
|
||||
"""
|
||||
Minimal HTTP server for OAuth callbacks in stdio mode.
|
||||
@@ -59,7 +64,9 @@ class MinimalOAuthServer:
|
||||
return create_error_response(error_message)
|
||||
|
||||
if not code:
|
||||
error_message = "Authentication failed: No authorization code received from Google."
|
||||
error_message = (
|
||||
"Authentication failed: No authorization code received from Google."
|
||||
)
|
||||
logger.error(error_message)
|
||||
return create_error_response(error_message)
|
||||
|
||||
@@ -69,7 +76,9 @@ class MinimalOAuthServer:
|
||||
if error_message:
|
||||
return create_server_error_response(error_message)
|
||||
|
||||
logger.info(f"OAuth callback: Received code (state: {state}). Attempting to exchange for tokens.")
|
||||
logger.info(
|
||||
f"OAuth callback: Received code (state: {state}). Attempting to exchange for tokens."
|
||||
)
|
||||
|
||||
# Session ID tracking removed - not needed
|
||||
|
||||
@@ -79,16 +88,20 @@ class MinimalOAuthServer:
|
||||
scopes=get_current_scopes(),
|
||||
authorization_response=str(request.url),
|
||||
redirect_uri=redirect_uri,
|
||||
session_id=None
|
||||
session_id=None,
|
||||
)
|
||||
|
||||
logger.info(f"OAuth callback: Successfully authenticated user: {verified_user_id} (state: {state}).")
|
||||
logger.info(
|
||||
f"OAuth callback: Successfully authenticated user: {verified_user_id} (state: {state})."
|
||||
)
|
||||
|
||||
# Return success page using shared template
|
||||
return create_success_response(verified_user_id)
|
||||
|
||||
except Exception as e:
|
||||
error_message_detail = f"Error processing OAuth callback (state: {state}): {str(e)}"
|
||||
error_message_detail = (
|
||||
f"Error processing OAuth callback (state: {state}): {str(e)}"
|
||||
)
|
||||
logger.error(error_message_detail, exc_info=True)
|
||||
return create_server_error_response(str(e))
|
||||
|
||||
@@ -101,24 +114,22 @@ class MinimalOAuthServer:
|
||||
"""Serve a stored attachment file."""
|
||||
storage = get_attachment_storage()
|
||||
metadata = storage.get_attachment_metadata(file_id)
|
||||
|
||||
|
||||
if not metadata:
|
||||
return JSONResponse(
|
||||
{"error": "Attachment not found or expired"},
|
||||
status_code=404
|
||||
{"error": "Attachment not found or expired"}, status_code=404
|
||||
)
|
||||
|
||||
|
||||
file_path = storage.get_attachment_path(file_id)
|
||||
if not file_path:
|
||||
return JSONResponse(
|
||||
{"error": "Attachment file not found"},
|
||||
status_code=404
|
||||
{"error": "Attachment file not found"}, status_code=404
|
||||
)
|
||||
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=metadata["filename"],
|
||||
media_type=metadata["mime_type"]
|
||||
media_type=metadata["mime_type"],
|
||||
)
|
||||
|
||||
def start(self) -> tuple[bool, str]:
|
||||
@@ -136,9 +147,9 @@ class MinimalOAuthServer:
|
||||
# Extract hostname from base_uri (e.g., "http://localhost" -> "localhost")
|
||||
try:
|
||||
parsed_uri = urlparse(self.base_uri)
|
||||
hostname = parsed_uri.hostname or 'localhost'
|
||||
hostname = parsed_uri.hostname or "localhost"
|
||||
except Exception:
|
||||
hostname = 'localhost'
|
||||
hostname = "localhost"
|
||||
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
@@ -156,7 +167,7 @@ class MinimalOAuthServer:
|
||||
host=hostname,
|
||||
port=self.port,
|
||||
log_level="warning",
|
||||
access_log=False
|
||||
access_log=False,
|
||||
)
|
||||
self.server = uvicorn.Server(config)
|
||||
asyncio.run(self.server.serve())
|
||||
@@ -178,7 +189,9 @@ class MinimalOAuthServer:
|
||||
result = s.connect_ex((hostname, self.port))
|
||||
if result == 0:
|
||||
self.is_running = True
|
||||
logger.info(f"Minimal OAuth server started on {hostname}:{self.port}")
|
||||
logger.info(
|
||||
f"Minimal OAuth server started on {hostname}:{self.port}"
|
||||
)
|
||||
return True, ""
|
||||
except Exception:
|
||||
pass
|
||||
@@ -195,7 +208,7 @@ class MinimalOAuthServer:
|
||||
|
||||
try:
|
||||
if self.server:
|
||||
if hasattr(self.server, 'should_exit'):
|
||||
if hasattr(self.server, "should_exit"):
|
||||
self.server.should_exit = True
|
||||
|
||||
if self.server_thread and self.server_thread.is_alive():
|
||||
@@ -211,7 +224,10 @@ class MinimalOAuthServer:
|
||||
# Global instance for stdio mode
|
||||
_minimal_oauth_server: Optional[MinimalOAuthServer] = None
|
||||
|
||||
def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8000, base_uri: str = "http://localhost") -> tuple[bool, str]:
|
||||
|
||||
def ensure_oauth_callback_available(
|
||||
transport_mode: str = "stdio", port: int = 8000, base_uri: str = "http://localhost"
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
Ensure OAuth callback endpoint is available for the given transport mode.
|
||||
|
||||
@@ -230,7 +246,9 @@ def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8
|
||||
|
||||
if transport_mode == "streamable-http":
|
||||
# In streamable-http mode, the main FastAPI server should handle callbacks
|
||||
logger.debug("Using existing FastAPI server for OAuth callbacks (streamable-http mode)")
|
||||
logger.debug(
|
||||
"Using existing FastAPI server for OAuth callbacks (streamable-http mode)"
|
||||
)
|
||||
return True, ""
|
||||
|
||||
elif transport_mode == "stdio":
|
||||
@@ -243,10 +261,14 @@ def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8
|
||||
logger.info("Starting minimal OAuth server for stdio mode")
|
||||
success, error_msg = _minimal_oauth_server.start()
|
||||
if success:
|
||||
logger.info(f"Minimal OAuth server successfully started on {base_uri}:{port}")
|
||||
logger.info(
|
||||
f"Minimal OAuth server successfully started on {base_uri}:{port}"
|
||||
)
|
||||
return True, ""
|
||||
else:
|
||||
logger.error(f"Failed to start minimal OAuth server on {base_uri}:{port}: {error_msg}")
|
||||
logger.error(
|
||||
f"Failed to start minimal OAuth server on {base_uri}:{port}: {error_msg}"
|
||||
)
|
||||
return False, error_msg
|
||||
else:
|
||||
logger.info("Minimal OAuth server is already running")
|
||||
@@ -257,6 +279,7 @@ def ensure_oauth_callback_available(transport_mode: str = "stdio", port: int = 8
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
|
||||
def cleanup_oauth_callback_server():
|
||||
"""Clean up the minimal OAuth server if it was started."""
|
||||
global _minimal_oauth_server
|
||||
|
||||
@@ -36,19 +36,31 @@ class OAuthConfig:
|
||||
self.client_secret = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
|
||||
|
||||
# OAuth 2.1 configuration
|
||||
self.oauth21_enabled = os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true"
|
||||
self.oauth21_enabled = (
|
||||
os.getenv("MCP_ENABLE_OAUTH21", "false").lower() == "true"
|
||||
)
|
||||
self.pkce_required = self.oauth21_enabled # PKCE is mandatory in OAuth 2.1
|
||||
self.supported_code_challenge_methods = ["S256", "plain"] if not self.oauth21_enabled else ["S256"]
|
||||
self.supported_code_challenge_methods = (
|
||||
["S256", "plain"] if not self.oauth21_enabled else ["S256"]
|
||||
)
|
||||
|
||||
# External OAuth 2.1 provider configuration
|
||||
self.external_oauth21_provider = os.getenv("EXTERNAL_OAUTH21_PROVIDER", "false").lower() == "true"
|
||||
self.external_oauth21_provider = (
|
||||
os.getenv("EXTERNAL_OAUTH21_PROVIDER", "false").lower() == "true"
|
||||
)
|
||||
if self.external_oauth21_provider and not self.oauth21_enabled:
|
||||
raise ValueError("EXTERNAL_OAUTH21_PROVIDER requires MCP_ENABLE_OAUTH21=true")
|
||||
raise ValueError(
|
||||
"EXTERNAL_OAUTH21_PROVIDER requires MCP_ENABLE_OAUTH21=true"
|
||||
)
|
||||
|
||||
# Stateless mode configuration
|
||||
self.stateless_mode = os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
|
||||
self.stateless_mode = (
|
||||
os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
|
||||
)
|
||||
if self.stateless_mode and not self.oauth21_enabled:
|
||||
raise ValueError("WORKSPACE_MCP_STATELESS_MODE requires MCP_ENABLE_OAUTH21=true")
|
||||
raise ValueError(
|
||||
"WORKSPACE_MCP_STATELESS_MODE requires MCP_ENABLE_OAUTH21=true"
|
||||
)
|
||||
|
||||
# Transport mode (will be set at runtime)
|
||||
self._transport_mode = "stdio" # Default
|
||||
@@ -95,7 +107,12 @@ class OAuthConfig:
|
||||
# Don't set FASTMCP_SERVER_AUTH if using external OAuth provider
|
||||
# (external OAuth means protocol-level auth is disabled, only tool-level auth)
|
||||
if not self.external_oauth21_provider:
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH", "fastmcp.server.auth.providers.google.GoogleProvider" if self.oauth21_enabled else None)
|
||||
_set_if_absent(
|
||||
"FASTMCP_SERVER_AUTH",
|
||||
"fastmcp.server.auth.providers.google.GoogleProvider"
|
||||
if self.oauth21_enabled
|
||||
else None,
|
||||
)
|
||||
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_ID", self.client_id)
|
||||
_set_if_absent("FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_SECRET", self.client_secret)
|
||||
@@ -135,11 +152,13 @@ class OAuthConfig:
|
||||
origins.append(self.base_url)
|
||||
|
||||
# VS Code and development origins
|
||||
origins.extend([
|
||||
"vscode-webview://",
|
||||
"https://vscode.dev",
|
||||
"https://github.dev",
|
||||
])
|
||||
origins.extend(
|
||||
[
|
||||
"vscode-webview://",
|
||||
"https://vscode.dev",
|
||||
"https://github.dev",
|
||||
]
|
||||
)
|
||||
|
||||
# Custom origins from environment
|
||||
custom_origins = os.getenv("OAUTH_ALLOWED_ORIGINS")
|
||||
@@ -266,6 +285,7 @@ class OAuthConfig:
|
||||
|
||||
# Use the structured type for cleaner detection logic
|
||||
from auth.oauth_types import OAuthVersionDetectionParams
|
||||
|
||||
params = OAuthVersionDetectionParams.from_request(request_params)
|
||||
|
||||
# Clear OAuth 2.1 indicator: PKCE is present
|
||||
@@ -278,6 +298,7 @@ class OAuthConfig:
|
||||
if authenticated_user:
|
||||
try:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
|
||||
store = get_oauth21_session_store()
|
||||
if store.has_session(authenticated_user):
|
||||
return "oauth21"
|
||||
@@ -291,7 +312,9 @@ class OAuthConfig:
|
||||
# Default to OAuth 2.0 for maximum compatibility
|
||||
return "oauth20"
|
||||
|
||||
def get_authorization_server_metadata(self, scopes: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
def get_authorization_server_metadata(
|
||||
self, scopes: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get OAuth authorization server metadata per RFC 8414.
|
||||
|
||||
@@ -311,7 +334,10 @@ class OAuthConfig:
|
||||
"userinfo_endpoint": "https://openidconnect.googleapis.com/v1/userinfo",
|
||||
"response_types_supported": ["code", "token"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
||||
"token_endpoint_auth_methods_supported": [
|
||||
"client_secret_post",
|
||||
"client_secret_basic",
|
||||
],
|
||||
"code_challenge_methods_supported": self.supported_code_challenge_methods,
|
||||
}
|
||||
|
||||
|
||||
@@ -220,4 +220,4 @@ def create_server_error_response(error_detail: str) -> HTMLResponse:
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return HTMLResponse(content=content, status_code=500)
|
||||
return HTMLResponse(content=content, status_code=500)
|
||||
|
||||
@@ -13,10 +13,11 @@ from typing import Optional, List, Dict, Any
|
||||
class OAuth21ServiceRequest:
|
||||
"""
|
||||
Encapsulates parameters for OAuth 2.1 service authentication requests.
|
||||
|
||||
|
||||
This parameter object pattern reduces function complexity and makes
|
||||
it easier to extend authentication parameters in the future.
|
||||
"""
|
||||
|
||||
service_name: str
|
||||
version: str
|
||||
tool_name: str
|
||||
@@ -26,7 +27,7 @@ class OAuth21ServiceRequest:
|
||||
auth_token_email: Optional[str] = None
|
||||
allow_recent_auth: bool = False
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
def to_legacy_params(self) -> dict:
|
||||
"""Convert to legacy parameter format for backward compatibility."""
|
||||
return {
|
||||
@@ -42,10 +43,11 @@ class OAuth21ServiceRequest:
|
||||
class OAuthVersionDetectionParams:
|
||||
"""
|
||||
Parameters used for OAuth version detection.
|
||||
|
||||
|
||||
Encapsulates the various signals we use to determine
|
||||
whether a client supports OAuth 2.1 or needs OAuth 2.0.
|
||||
"""
|
||||
|
||||
client_id: Optional[str] = None
|
||||
client_secret: Optional[str] = None
|
||||
code_challenge: Optional[str] = None
|
||||
@@ -53,9 +55,11 @@ class OAuthVersionDetectionParams:
|
||||
code_verifier: Optional[str] = None
|
||||
authenticated_user: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, request_params: Dict[str, Any]) -> "OAuthVersionDetectionParams":
|
||||
def from_request(
|
||||
cls, request_params: Dict[str, Any]
|
||||
) -> "OAuthVersionDetectionParams":
|
||||
"""Create from raw request parameters."""
|
||||
return cls(
|
||||
client_id=request_params.get("client_id"),
|
||||
@@ -66,13 +70,13 @@ class OAuthVersionDetectionParams:
|
||||
authenticated_user=request_params.get("authenticated_user"),
|
||||
session_id=request_params.get("session_id"),
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def has_pkce(self) -> bool:
|
||||
"""Check if PKCE parameters are present."""
|
||||
return bool(self.code_challenge or self.code_verifier)
|
||||
|
||||
|
||||
@property
|
||||
def is_public_client(self) -> bool:
|
||||
"""Check if this appears to be a public client (no secret)."""
|
||||
return bool(self.client_id and not self.client_secret)
|
||||
return bool(self.client_id and not self.client_secret)
|
||||
|
||||
165
auth/scopes.py
165
auth/scopes.py
@@ -4,6 +4,7 @@ Google Workspace OAuth Scopes
|
||||
This module centralizes OAuth scope definitions for Google Workspace integration.
|
||||
Separated from service_decorator.py to avoid circular imports.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -12,79 +13,66 @@ logger = logging.getLogger(__name__)
|
||||
_ENABLED_TOOLS = None
|
||||
|
||||
# Individual OAuth Scope Constants
|
||||
USERINFO_EMAIL_SCOPE = 'https://www.googleapis.com/auth/userinfo.email'
|
||||
USERINFO_PROFILE_SCOPE = 'https://www.googleapis.com/auth/userinfo.profile'
|
||||
OPENID_SCOPE = 'openid'
|
||||
CALENDAR_SCOPE = 'https://www.googleapis.com/auth/calendar'
|
||||
CALENDAR_READONLY_SCOPE = 'https://www.googleapis.com/auth/calendar.readonly'
|
||||
CALENDAR_EVENTS_SCOPE = 'https://www.googleapis.com/auth/calendar.events'
|
||||
USERINFO_EMAIL_SCOPE = "https://www.googleapis.com/auth/userinfo.email"
|
||||
USERINFO_PROFILE_SCOPE = "https://www.googleapis.com/auth/userinfo.profile"
|
||||
OPENID_SCOPE = "openid"
|
||||
CALENDAR_SCOPE = "https://www.googleapis.com/auth/calendar"
|
||||
CALENDAR_READONLY_SCOPE = "https://www.googleapis.com/auth/calendar.readonly"
|
||||
CALENDAR_EVENTS_SCOPE = "https://www.googleapis.com/auth/calendar.events"
|
||||
|
||||
# Google Drive scopes
|
||||
DRIVE_SCOPE = 'https://www.googleapis.com/auth/drive'
|
||||
DRIVE_READONLY_SCOPE = 'https://www.googleapis.com/auth/drive.readonly'
|
||||
DRIVE_FILE_SCOPE = 'https://www.googleapis.com/auth/drive.file'
|
||||
DRIVE_SCOPE = "https://www.googleapis.com/auth/drive"
|
||||
DRIVE_READONLY_SCOPE = "https://www.googleapis.com/auth/drive.readonly"
|
||||
DRIVE_FILE_SCOPE = "https://www.googleapis.com/auth/drive.file"
|
||||
|
||||
# Google Docs scopes
|
||||
DOCS_READONLY_SCOPE = 'https://www.googleapis.com/auth/documents.readonly'
|
||||
DOCS_WRITE_SCOPE = 'https://www.googleapis.com/auth/documents'
|
||||
DOCS_READONLY_SCOPE = "https://www.googleapis.com/auth/documents.readonly"
|
||||
DOCS_WRITE_SCOPE = "https://www.googleapis.com/auth/documents"
|
||||
|
||||
# Gmail API scopes
|
||||
GMAIL_READONLY_SCOPE = 'https://www.googleapis.com/auth/gmail.readonly'
|
||||
GMAIL_SEND_SCOPE = 'https://www.googleapis.com/auth/gmail.send'
|
||||
GMAIL_COMPOSE_SCOPE = 'https://www.googleapis.com/auth/gmail.compose'
|
||||
GMAIL_MODIFY_SCOPE = 'https://www.googleapis.com/auth/gmail.modify'
|
||||
GMAIL_LABELS_SCOPE = 'https://www.googleapis.com/auth/gmail.labels'
|
||||
GMAIL_SETTINGS_BASIC_SCOPE = 'https://www.googleapis.com/auth/gmail.settings.basic'
|
||||
GMAIL_READONLY_SCOPE = "https://www.googleapis.com/auth/gmail.readonly"
|
||||
GMAIL_SEND_SCOPE = "https://www.googleapis.com/auth/gmail.send"
|
||||
GMAIL_COMPOSE_SCOPE = "https://www.googleapis.com/auth/gmail.compose"
|
||||
GMAIL_MODIFY_SCOPE = "https://www.googleapis.com/auth/gmail.modify"
|
||||
GMAIL_LABELS_SCOPE = "https://www.googleapis.com/auth/gmail.labels"
|
||||
GMAIL_SETTINGS_BASIC_SCOPE = "https://www.googleapis.com/auth/gmail.settings.basic"
|
||||
|
||||
# Google Chat API scopes
|
||||
CHAT_READONLY_SCOPE = 'https://www.googleapis.com/auth/chat.messages.readonly'
|
||||
CHAT_WRITE_SCOPE = 'https://www.googleapis.com/auth/chat.messages'
|
||||
CHAT_SPACES_SCOPE = 'https://www.googleapis.com/auth/chat.spaces'
|
||||
CHAT_READONLY_SCOPE = "https://www.googleapis.com/auth/chat.messages.readonly"
|
||||
CHAT_WRITE_SCOPE = "https://www.googleapis.com/auth/chat.messages"
|
||||
CHAT_SPACES_SCOPE = "https://www.googleapis.com/auth/chat.spaces"
|
||||
|
||||
# Google Sheets API scopes
|
||||
SHEETS_READONLY_SCOPE = 'https://www.googleapis.com/auth/spreadsheets.readonly'
|
||||
SHEETS_WRITE_SCOPE = 'https://www.googleapis.com/auth/spreadsheets'
|
||||
SHEETS_READONLY_SCOPE = "https://www.googleapis.com/auth/spreadsheets.readonly"
|
||||
SHEETS_WRITE_SCOPE = "https://www.googleapis.com/auth/spreadsheets"
|
||||
|
||||
# Google Forms API scopes
|
||||
FORMS_BODY_SCOPE = 'https://www.googleapis.com/auth/forms.body'
|
||||
FORMS_BODY_READONLY_SCOPE = 'https://www.googleapis.com/auth/forms.body.readonly'
|
||||
FORMS_RESPONSES_READONLY_SCOPE = 'https://www.googleapis.com/auth/forms.responses.readonly'
|
||||
FORMS_BODY_SCOPE = "https://www.googleapis.com/auth/forms.body"
|
||||
FORMS_BODY_READONLY_SCOPE = "https://www.googleapis.com/auth/forms.body.readonly"
|
||||
FORMS_RESPONSES_READONLY_SCOPE = (
|
||||
"https://www.googleapis.com/auth/forms.responses.readonly"
|
||||
)
|
||||
|
||||
# Google Slides API scopes
|
||||
SLIDES_SCOPE = 'https://www.googleapis.com/auth/presentations'
|
||||
SLIDES_READONLY_SCOPE = 'https://www.googleapis.com/auth/presentations.readonly'
|
||||
SLIDES_SCOPE = "https://www.googleapis.com/auth/presentations"
|
||||
SLIDES_READONLY_SCOPE = "https://www.googleapis.com/auth/presentations.readonly"
|
||||
|
||||
# Google Tasks API scopes
|
||||
TASKS_SCOPE = 'https://www.googleapis.com/auth/tasks'
|
||||
TASKS_READONLY_SCOPE = 'https://www.googleapis.com/auth/tasks.readonly'
|
||||
TASKS_SCOPE = "https://www.googleapis.com/auth/tasks"
|
||||
TASKS_READONLY_SCOPE = "https://www.googleapis.com/auth/tasks.readonly"
|
||||
|
||||
# Google Custom Search API scope
|
||||
CUSTOM_SEARCH_SCOPE = 'https://www.googleapis.com/auth/cse'
|
||||
CUSTOM_SEARCH_SCOPE = "https://www.googleapis.com/auth/cse"
|
||||
|
||||
# Base OAuth scopes required for user identification
|
||||
BASE_SCOPES = [
|
||||
USERINFO_EMAIL_SCOPE,
|
||||
USERINFO_PROFILE_SCOPE,
|
||||
OPENID_SCOPE
|
||||
]
|
||||
BASE_SCOPES = [USERINFO_EMAIL_SCOPE, USERINFO_PROFILE_SCOPE, OPENID_SCOPE]
|
||||
|
||||
# Service-specific scope groups
|
||||
DOCS_SCOPES = [
|
||||
DOCS_READONLY_SCOPE,
|
||||
DOCS_WRITE_SCOPE
|
||||
]
|
||||
DOCS_SCOPES = [DOCS_READONLY_SCOPE, DOCS_WRITE_SCOPE]
|
||||
|
||||
CALENDAR_SCOPES = [
|
||||
CALENDAR_SCOPE,
|
||||
CALENDAR_READONLY_SCOPE,
|
||||
CALENDAR_EVENTS_SCOPE
|
||||
]
|
||||
CALENDAR_SCOPES = [CALENDAR_SCOPE, CALENDAR_READONLY_SCOPE, CALENDAR_EVENTS_SCOPE]
|
||||
|
||||
DRIVE_SCOPES = [
|
||||
DRIVE_SCOPE,
|
||||
DRIVE_READONLY_SCOPE,
|
||||
DRIVE_FILE_SCOPE
|
||||
]
|
||||
DRIVE_SCOPES = [DRIVE_SCOPE, DRIVE_READONLY_SCOPE, DRIVE_FILE_SCOPE]
|
||||
|
||||
GMAIL_SCOPES = [
|
||||
GMAIL_READONLY_SCOPE,
|
||||
@@ -92,58 +80,44 @@ GMAIL_SCOPES = [
|
||||
GMAIL_COMPOSE_SCOPE,
|
||||
GMAIL_MODIFY_SCOPE,
|
||||
GMAIL_LABELS_SCOPE,
|
||||
GMAIL_SETTINGS_BASIC_SCOPE
|
||||
GMAIL_SETTINGS_BASIC_SCOPE,
|
||||
]
|
||||
|
||||
CHAT_SCOPES = [
|
||||
CHAT_READONLY_SCOPE,
|
||||
CHAT_WRITE_SCOPE,
|
||||
CHAT_SPACES_SCOPE
|
||||
]
|
||||
CHAT_SCOPES = [CHAT_READONLY_SCOPE, CHAT_WRITE_SCOPE, CHAT_SPACES_SCOPE]
|
||||
|
||||
SHEETS_SCOPES = [
|
||||
SHEETS_READONLY_SCOPE,
|
||||
SHEETS_WRITE_SCOPE
|
||||
]
|
||||
SHEETS_SCOPES = [SHEETS_READONLY_SCOPE, SHEETS_WRITE_SCOPE]
|
||||
|
||||
FORMS_SCOPES = [
|
||||
FORMS_BODY_SCOPE,
|
||||
FORMS_BODY_READONLY_SCOPE,
|
||||
FORMS_RESPONSES_READONLY_SCOPE
|
||||
FORMS_RESPONSES_READONLY_SCOPE,
|
||||
]
|
||||
|
||||
SLIDES_SCOPES = [
|
||||
SLIDES_SCOPE,
|
||||
SLIDES_READONLY_SCOPE
|
||||
]
|
||||
SLIDES_SCOPES = [SLIDES_SCOPE, SLIDES_READONLY_SCOPE]
|
||||
|
||||
TASKS_SCOPES = [
|
||||
TASKS_SCOPE,
|
||||
TASKS_READONLY_SCOPE
|
||||
]
|
||||
TASKS_SCOPES = [TASKS_SCOPE, TASKS_READONLY_SCOPE]
|
||||
|
||||
CUSTOM_SEARCH_SCOPES = [
|
||||
CUSTOM_SEARCH_SCOPE
|
||||
]
|
||||
CUSTOM_SEARCH_SCOPES = [CUSTOM_SEARCH_SCOPE]
|
||||
|
||||
# Tool-to-scopes mapping
|
||||
TOOL_SCOPES_MAP = {
|
||||
'gmail': GMAIL_SCOPES,
|
||||
'drive': DRIVE_SCOPES,
|
||||
'calendar': CALENDAR_SCOPES,
|
||||
'docs': DOCS_SCOPES,
|
||||
'sheets': SHEETS_SCOPES,
|
||||
'chat': CHAT_SCOPES,
|
||||
'forms': FORMS_SCOPES,
|
||||
'slides': SLIDES_SCOPES,
|
||||
'tasks': TASKS_SCOPES,
|
||||
'search': CUSTOM_SEARCH_SCOPES
|
||||
"gmail": GMAIL_SCOPES,
|
||||
"drive": DRIVE_SCOPES,
|
||||
"calendar": CALENDAR_SCOPES,
|
||||
"docs": DOCS_SCOPES,
|
||||
"sheets": SHEETS_SCOPES,
|
||||
"chat": CHAT_SCOPES,
|
||||
"forms": FORMS_SCOPES,
|
||||
"slides": SLIDES_SCOPES,
|
||||
"tasks": TASKS_SCOPES,
|
||||
"search": CUSTOM_SEARCH_SCOPES,
|
||||
}
|
||||
|
||||
|
||||
def set_enabled_tools(enabled_tools):
|
||||
"""
|
||||
Set the globally enabled tools list.
|
||||
|
||||
|
||||
Args:
|
||||
enabled_tools: List of enabled tool names.
|
||||
"""
|
||||
@@ -151,11 +125,12 @@ def set_enabled_tools(enabled_tools):
|
||||
_ENABLED_TOOLS = enabled_tools
|
||||
logger.info(f"Enabled tools set for scope management: {enabled_tools}")
|
||||
|
||||
|
||||
def get_current_scopes():
|
||||
"""
|
||||
Returns scopes for currently enabled tools.
|
||||
Uses globally set enabled tools or all tools if not set.
|
||||
|
||||
|
||||
Returns:
|
||||
List of unique scopes for the enabled tools plus base scopes.
|
||||
"""
|
||||
@@ -163,43 +138,47 @@ def get_current_scopes():
|
||||
if enabled_tools is None:
|
||||
# Default behavior - return all scopes
|
||||
enabled_tools = TOOL_SCOPES_MAP.keys()
|
||||
|
||||
|
||||
# Start with base scopes (always required)
|
||||
scopes = BASE_SCOPES.copy()
|
||||
|
||||
|
||||
# Add scopes for each enabled tool
|
||||
for tool in enabled_tools:
|
||||
if tool in TOOL_SCOPES_MAP:
|
||||
scopes.extend(TOOL_SCOPES_MAP[tool])
|
||||
|
||||
logger.debug(f"Generated scopes for tools {list(enabled_tools)}: {len(set(scopes))} unique scopes")
|
||||
|
||||
logger.debug(
|
||||
f"Generated scopes for tools {list(enabled_tools)}: {len(set(scopes))} unique scopes"
|
||||
)
|
||||
# Return unique scopes
|
||||
return list(set(scopes))
|
||||
|
||||
|
||||
def get_scopes_for_tools(enabled_tools=None):
|
||||
"""
|
||||
Returns scopes for enabled tools only.
|
||||
|
||||
|
||||
Args:
|
||||
enabled_tools: List of enabled tool names. If None, returns all scopes.
|
||||
|
||||
|
||||
Returns:
|
||||
List of unique scopes for the enabled tools plus base scopes.
|
||||
"""
|
||||
if enabled_tools is None:
|
||||
# Default behavior - return all scopes
|
||||
enabled_tools = TOOL_SCOPES_MAP.keys()
|
||||
|
||||
|
||||
# Start with base scopes (always required)
|
||||
scopes = BASE_SCOPES.copy()
|
||||
|
||||
|
||||
# Add scopes for each enabled tool
|
||||
for tool in enabled_tools:
|
||||
if tool in TOOL_SCOPES_MAP:
|
||||
scopes.extend(TOOL_SCOPES_MAP[tool])
|
||||
|
||||
|
||||
# Return unique scopes
|
||||
return list(set(scopes))
|
||||
|
||||
|
||||
# Combined scopes for all supported Google Workspace operations (backwards compatibility)
|
||||
SCOPES = get_scopes_for_tools()
|
||||
SCOPES = get_scopes_for_tools()
|
||||
|
||||
@@ -46,6 +46,7 @@ from auth.scopes import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Authentication helper functions
|
||||
def _get_auth_context(
|
||||
tool_name: str,
|
||||
@@ -136,7 +137,9 @@ def _override_oauth21_user_email(
|
||||
Returns:
|
||||
Tuple of (updated_user_email, updated_args)
|
||||
"""
|
||||
if not (use_oauth21 and authenticated_user and current_user_email != authenticated_user):
|
||||
if not (
|
||||
use_oauth21 and authenticated_user and current_user_email != authenticated_user
|
||||
):
|
||||
return current_user_email, args
|
||||
|
||||
service_suffix = f" for service '{service_type}'" if service_type else ""
|
||||
@@ -235,7 +238,9 @@ async def get_authenticated_google_service_oauth21(
|
||||
f"Authenticated account {token_email} does not match requested user {user_google_email}."
|
||||
)
|
||||
|
||||
credentials = ensure_session_from_access_token(access_token, resolved_email, session_id)
|
||||
credentials = ensure_session_from_access_token(
|
||||
access_token, resolved_email, session_id
|
||||
)
|
||||
if not credentials:
|
||||
raise GoogleAuthenticationError(
|
||||
"Unable to build Google credentials from authenticated access token."
|
||||
@@ -286,7 +291,9 @@ async def get_authenticated_google_service_oauth21(
|
||||
return service, user_google_email
|
||||
|
||||
|
||||
def _extract_oauth21_user_email(authenticated_user: Optional[str], func_name: str) -> str:
|
||||
def _extract_oauth21_user_email(
|
||||
authenticated_user: Optional[str], func_name: str
|
||||
) -> str:
|
||||
"""
|
||||
Extract user email for OAuth 2.1 mode.
|
||||
|
||||
@@ -308,9 +315,7 @@ def _extract_oauth21_user_email(authenticated_user: Optional[str], func_name: st
|
||||
|
||||
|
||||
def _extract_oauth20_user_email(
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
wrapper_sig: inspect.Signature
|
||||
args: tuple, kwargs: dict, wrapper_sig: inspect.Signature
|
||||
) -> str:
|
||||
"""
|
||||
Extract user email for OAuth 2.0 mode from function arguments.
|
||||
@@ -331,13 +336,10 @@ def _extract_oauth20_user_email(
|
||||
|
||||
user_google_email = bound_args.arguments.get("user_google_email")
|
||||
if not user_google_email:
|
||||
raise Exception(
|
||||
"'user_google_email' parameter is required but was not found."
|
||||
)
|
||||
raise Exception("'user_google_email' parameter is required but was not found.")
|
||||
return user_google_email
|
||||
|
||||
|
||||
|
||||
def _remove_user_email_arg_from_docstring(docstring: str) -> str:
|
||||
"""
|
||||
Remove user_google_email parameter documentation from docstring.
|
||||
@@ -357,19 +359,20 @@ def _remove_user_email_arg_from_docstring(docstring: str) -> str:
|
||||
# - user_google_email: Description
|
||||
# - user_google_email (str) - Description
|
||||
patterns = [
|
||||
r'^\s*user_google_email\s*\([^)]*\)\s*:\s*[^\n]*\.?\s*(?:Required\.?)?\s*\n',
|
||||
r'^\s*user_google_email\s*:\s*[^\n]*\n',
|
||||
r'^\s*user_google_email\s*\([^)]*\)\s*-\s*[^\n]*\n',
|
||||
r"^\s*user_google_email\s*\([^)]*\)\s*:\s*[^\n]*\.?\s*(?:Required\.?)?\s*\n",
|
||||
r"^\s*user_google_email\s*:\s*[^\n]*\n",
|
||||
r"^\s*user_google_email\s*\([^)]*\)\s*-\s*[^\n]*\n",
|
||||
]
|
||||
|
||||
modified_docstring = docstring
|
||||
for pattern in patterns:
|
||||
modified_docstring = re.sub(pattern, '', modified_docstring, flags=re.MULTILINE)
|
||||
modified_docstring = re.sub(pattern, "", modified_docstring, flags=re.MULTILINE)
|
||||
|
||||
# Clean up any sequence of 3 or more newlines that might have been created
|
||||
modified_docstring = re.sub(r'\n{3,}', '\n\n', modified_docstring)
|
||||
modified_docstring = re.sub(r"\n{3,}", "\n\n", modified_docstring)
|
||||
return modified_docstring
|
||||
|
||||
|
||||
# Service configuration mapping
|
||||
SERVICE_CONFIGS = {
|
||||
"gmail": {"service": "gmail", "version": "v1"},
|
||||
@@ -425,7 +428,6 @@ SCOPE_GROUPS = {
|
||||
}
|
||||
|
||||
|
||||
|
||||
def _resolve_scopes(scopes: Union[str, List[str]]) -> List[str]:
|
||||
"""Resolve scope names to actual scope URLs."""
|
||||
if isinstance(scopes, str):
|
||||
@@ -467,7 +469,6 @@ def _handle_token_refresh_error(
|
||||
f"Token expired or revoked for user {user_email} accessing {service_name}"
|
||||
)
|
||||
|
||||
|
||||
service_display_name = f"Google {service_name.title()}"
|
||||
|
||||
return (
|
||||
@@ -527,10 +528,7 @@ def require_google_service(
|
||||
# In OAuth 2.1 mode, also exclude 'user_google_email' since it's automatically determined.
|
||||
if is_oauth21_enabled():
|
||||
# Remove both 'service' and 'user_google_email' parameters
|
||||
filtered_params = [
|
||||
p for p in params[1:]
|
||||
if p.name != 'user_google_email'
|
||||
]
|
||||
filtered_params = [p for p in params[1:] if p.name != "user_google_email"]
|
||||
wrapper_sig = original_sig.replace(parameters=filtered_params)
|
||||
else:
|
||||
# Only remove 'service' parameter for OAuth 2.0 mode
|
||||
@@ -548,9 +546,13 @@ def require_google_service(
|
||||
|
||||
# Extract user_google_email based on OAuth mode
|
||||
if is_oauth21_enabled():
|
||||
user_google_email = _extract_oauth21_user_email(authenticated_user, func.__name__)
|
||||
user_google_email = _extract_oauth21_user_email(
|
||||
authenticated_user, func.__name__
|
||||
)
|
||||
else:
|
||||
user_google_email = _extract_oauth20_user_email(args, kwargs, wrapper_sig)
|
||||
user_google_email = _extract_oauth20_user_email(
|
||||
args, kwargs, wrapper_sig
|
||||
)
|
||||
|
||||
# Get service configuration from the decorator's arguments
|
||||
if service_type not in SERVICE_CONFIGS:
|
||||
@@ -628,7 +630,9 @@ def require_google_service(
|
||||
|
||||
# Conditionally modify docstring to remove user_google_email parameter documentation
|
||||
if is_oauth21_enabled():
|
||||
logger.debug('OAuth 2.1 mode enabled, removing user_google_email from docstring')
|
||||
logger.debug(
|
||||
"OAuth 2.1 mode enabled, removing user_google_email from docstring"
|
||||
)
|
||||
if func.__doc__:
|
||||
wrapper.__doc__ = _remove_user_email_arg_from_docstring(func.__doc__)
|
||||
|
||||
@@ -664,14 +668,10 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
||||
params = list(original_sig.parameters.values())
|
||||
|
||||
# Remove injected service params from the wrapper signature; drop user_google_email only for OAuth 2.1.
|
||||
filtered_params = [
|
||||
p for p in params
|
||||
if p.name not in service_param_names
|
||||
]
|
||||
filtered_params = [p for p in params if p.name not in service_param_names]
|
||||
if is_oauth21_enabled():
|
||||
filtered_params = [
|
||||
p for p in filtered_params
|
||||
if p.name != 'user_google_email'
|
||||
p for p in filtered_params if p.name != "user_google_email"
|
||||
]
|
||||
|
||||
wrapper_sig = original_sig.replace(parameters=filtered_params)
|
||||
@@ -685,9 +685,13 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
||||
|
||||
# Extract user_google_email based on OAuth mode
|
||||
if is_oauth21_enabled():
|
||||
user_google_email = _extract_oauth21_user_email(authenticated_user, tool_name)
|
||||
user_google_email = _extract_oauth21_user_email(
|
||||
authenticated_user, tool_name
|
||||
)
|
||||
else:
|
||||
user_google_email = _extract_oauth20_user_email(args, kwargs, wrapper_sig)
|
||||
user_google_email = _extract_oauth20_user_email(
|
||||
args, kwargs, wrapper_sig
|
||||
)
|
||||
|
||||
# Authenticate all services
|
||||
for config in service_configs:
|
||||
@@ -764,12 +768,12 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
||||
|
||||
# Conditionally modify docstring to remove user_google_email parameter documentation
|
||||
if is_oauth21_enabled():
|
||||
logger.debug('OAuth 2.1 mode enabled, removing user_google_email from docstring')
|
||||
logger.debug(
|
||||
"OAuth 2.1 mode enabled, removing user_google_email from docstring"
|
||||
)
|
||||
if func.__doc__:
|
||||
wrapper.__doc__ = _remove_user_email_arg_from_docstring(func.__doc__)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user