Merge branch 'main' of github.com:taylorwilsdon/google_workspace_mcp into feat/contacts-api
This commit is contained in:
@@ -4,13 +4,13 @@ Authentication middleware to populate context state with user information
|
||||
|
||||
import jwt
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
from auth.oauth21_session_store import ensure_session_from_access_token
|
||||
from auth.oauth_types import WorkspaceAccessToken
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -32,254 +32,257 @@ class AuthInfoMiddleware(Middleware):
|
||||
logger.warning("No fastmcp_context available")
|
||||
return
|
||||
|
||||
# Return early if authentication state is already set
|
||||
if context.fastmcp_context.get_state("authenticated_user_email"):
|
||||
logger.info("Authentication state already set.")
|
||||
return
|
||||
authenticated_user = None
|
||||
auth_via = None
|
||||
|
||||
# First check if FastMCP has already validated an access token
|
||||
try:
|
||||
access_token = get_access_token()
|
||||
if access_token:
|
||||
logger.info(
|
||||
f"[AuthInfoMiddleware] FastMCP access_token found: {type(access_token)}"
|
||||
)
|
||||
user_email = getattr(access_token, "email", None)
|
||||
if not user_email and hasattr(access_token, "claims"):
|
||||
user_email = access_token.claims.get("email")
|
||||
|
||||
if user_email:
|
||||
logger.info(
|
||||
f"✓ Using FastMCP validated token for user: {user_email}"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "fastmcp_oauth"
|
||||
)
|
||||
context.fastmcp_context.set_state("access_token", access_token)
|
||||
authenticated_user = user_email
|
||||
auth_via = "fastmcp_oauth"
|
||||
else:
|
||||
logger.warning(
|
||||
f"FastMCP access_token found but no email. Type: {type(access_token).__name__}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get FastMCP access_token: {e}")
|
||||
|
||||
# Try to get the HTTP request to extract Authorization header
|
||||
try:
|
||||
# Use the new FastMCP method to get HTTP headers
|
||||
headers = get_http_headers()
|
||||
if headers:
|
||||
logger.debug("Processing HTTP headers for authentication")
|
||||
if not authenticated_user:
|
||||
try:
|
||||
# Use the new FastMCP method to get HTTP headers
|
||||
headers = get_http_headers()
|
||||
logger.info(
|
||||
f"[AuthInfoMiddleware] get_http_headers() returned: {headers is not None}, keys: {list(headers.keys()) if headers else 'None'}"
|
||||
)
|
||||
if headers:
|
||||
logger.debug("Processing HTTP headers for authentication")
|
||||
|
||||
# Get the Authorization header
|
||||
auth_header = headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token_str = auth_header[7:] # Remove "Bearer " prefix
|
||||
logger.debug("Found Bearer token")
|
||||
# Get the Authorization header
|
||||
auth_header = headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token_str = auth_header[7:] # Remove "Bearer " prefix
|
||||
logger.info(f"Found Bearer token: {token_str[:20]}...")
|
||||
|
||||
# For Google OAuth tokens (ya29.*), we need to verify them differently
|
||||
if token_str.startswith("ya29."):
|
||||
logger.debug("Detected Google OAuth access token format")
|
||||
# For Google OAuth tokens (ya29.*), we need to verify them differently
|
||||
if token_str.startswith("ya29."):
|
||||
logger.debug("Detected Google OAuth access token format")
|
||||
|
||||
# Verify the token to get user info
|
||||
from core.server import get_auth_provider
|
||||
# Verify the token to get user info
|
||||
from core.server import get_auth_provider
|
||||
|
||||
auth_provider = get_auth_provider()
|
||||
auth_provider = get_auth_provider()
|
||||
|
||||
if auth_provider:
|
||||
try:
|
||||
# Verify the token
|
||||
verified_auth = await auth_provider.verify_token(
|
||||
token_str
|
||||
)
|
||||
if verified_auth:
|
||||
# Extract user info from verified token
|
||||
user_email = None
|
||||
if hasattr(verified_auth, "claims"):
|
||||
user_email = verified_auth.claims.get("email")
|
||||
if auth_provider:
|
||||
try:
|
||||
# Verify the token
|
||||
verified_auth = await auth_provider.verify_token(
|
||||
token_str
|
||||
)
|
||||
if verified_auth:
|
||||
# Extract user info from verified token
|
||||
user_email = None
|
||||
if hasattr(verified_auth, "claims"):
|
||||
user_email = verified_auth.claims.get(
|
||||
"email"
|
||||
)
|
||||
|
||||
# Get expires_at, defaulting to 1 hour from now if not available
|
||||
if hasattr(verified_auth, "expires_at"):
|
||||
expires_at = verified_auth.expires_at
|
||||
# Get expires_at, defaulting to 1 hour from now if not available
|
||||
if hasattr(verified_auth, "expires_at"):
|
||||
expires_at = verified_auth.expires_at
|
||||
else:
|
||||
expires_at = (
|
||||
int(time.time()) + 3600
|
||||
) # Default to 1 hour
|
||||
|
||||
# Get client_id from verified auth or use default
|
||||
client_id = (
|
||||
getattr(verified_auth, "client_id", None)
|
||||
or "google"
|
||||
)
|
||||
|
||||
access_token = WorkspaceAccessToken(
|
||||
token=token_str,
|
||||
client_id=client_id,
|
||||
scopes=verified_auth.scopes
|
||||
if hasattr(verified_auth, "scopes")
|
||||
else [],
|
||||
session_id=f"google_oauth_{token_str[:8]}",
|
||||
expires_at=expires_at,
|
||||
claims=getattr(verified_auth, "claims", {})
|
||||
or {},
|
||||
sub=verified_auth.sub
|
||||
if hasattr(verified_auth, "sub")
|
||||
else user_email,
|
||||
email=user_email,
|
||||
)
|
||||
|
||||
# Store in context state - this is the authoritative authentication state
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token", access_token
|
||||
)
|
||||
mcp_session_id = getattr(
|
||||
context.fastmcp_context, "session_id", None
|
||||
)
|
||||
ensure_session_from_access_token(
|
||||
verified_auth,
|
||||
user_email,
|
||||
mcp_session_id,
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token_obj", verified_auth
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type",
|
||||
self.auth_provider_type,
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"token_type", "google_oauth"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"username", user_email
|
||||
)
|
||||
# Set the definitive authentication state
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "bearer_token"
|
||||
)
|
||||
authenticated_user = user_email
|
||||
auth_via = "bearer_token"
|
||||
else:
|
||||
expires_at = (
|
||||
int(time.time()) + 3600
|
||||
) # Default to 1 hour
|
||||
|
||||
# Get client_id from verified auth or use default
|
||||
client_id = (
|
||||
getattr(verified_auth, "client_id", None)
|
||||
or "google"
|
||||
logger.error(
|
||||
"Failed to verify Google OAuth token"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error verifying Google OAuth token: {e}"
|
||||
)
|
||||
|
||||
access_token = SimpleNamespace(
|
||||
token=token_str,
|
||||
client_id=client_id,
|
||||
scopes=verified_auth.scopes
|
||||
if hasattr(verified_auth, "scopes")
|
||||
else [],
|
||||
session_id=f"google_oauth_{token_str[:8]}",
|
||||
expires_at=expires_at,
|
||||
# Add other fields that might be needed
|
||||
sub=verified_auth.sub
|
||||
if hasattr(verified_auth, "sub")
|
||||
else user_email,
|
||||
email=user_email,
|
||||
)
|
||||
|
||||
# Store in context state - this is the authoritative authentication state
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token", access_token
|
||||
)
|
||||
mcp_session_id = getattr(
|
||||
context.fastmcp_context, "session_id", None
|
||||
)
|
||||
ensure_session_from_access_token(
|
||||
verified_auth,
|
||||
user_email,
|
||||
mcp_session_id,
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token_obj", verified_auth
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", self.auth_provider_type
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"token_type", "google_oauth"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"username", user_email
|
||||
)
|
||||
# Set the definitive authentication state
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "bearer_token"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Authenticated via Google OAuth: {user_email}"
|
||||
)
|
||||
else:
|
||||
logger.error("Failed to verify Google OAuth token")
|
||||
# Don't set authenticated_user_email if verification failed
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying Google OAuth token: {e}")
|
||||
# Still store the unverified token - service decorator will handle verification
|
||||
access_token = SimpleNamespace(
|
||||
token=token_str,
|
||||
client_id=os.getenv(
|
||||
"GOOGLE_OAUTH_CLIENT_ID", "google"
|
||||
),
|
||||
scopes=[],
|
||||
session_id=f"google_oauth_{token_str[:8]}",
|
||||
expires_at=int(time.time())
|
||||
+ 3600, # Default to 1 hour
|
||||
sub="unknown",
|
||||
email="",
|
||||
else:
|
||||
logger.warning(
|
||||
"No auth provider available to verify Google token"
|
||||
)
|
||||
|
||||
else:
|
||||
# Decode JWT to get user info
|
||||
logger.info("Processing JWT token")
|
||||
try:
|
||||
token_payload = jwt.decode(
|
||||
token_str, options={"verify_signature": False}
|
||||
)
|
||||
logger.info(
|
||||
f"JWT payload decoded: {list(token_payload.keys())}"
|
||||
)
|
||||
|
||||
# Create an AccessToken-like object
|
||||
access_token = WorkspaceAccessToken(
|
||||
token=token_str,
|
||||
client_id=token_payload.get("client_id", "unknown"),
|
||||
scopes=token_payload.get("scope", "").split()
|
||||
if token_payload.get("scope")
|
||||
else [],
|
||||
session_id=token_payload.get(
|
||||
"sid",
|
||||
token_payload.get(
|
||||
"jti",
|
||||
token_payload.get("session_id", "unknown"),
|
||||
),
|
||||
),
|
||||
expires_at=token_payload.get("exp", 0),
|
||||
claims=token_payload,
|
||||
sub=token_payload.get("sub"),
|
||||
email=token_payload.get("email"),
|
||||
)
|
||||
|
||||
# Store in context state
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token", access_token
|
||||
)
|
||||
|
||||
# Store additional user info
|
||||
context.fastmcp_context.set_state(
|
||||
"user_id", token_payload.get("sub")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"username",
|
||||
token_payload.get(
|
||||
"username", token_payload.get("email")
|
||||
),
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"name", token_payload.get("name")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_time", token_payload.get("auth_time")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"issuer", token_payload.get("iss")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"audience", token_payload.get("aud")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"jti", token_payload.get("jti")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", self.auth_provider_type
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"token_type", "google_oauth"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"No auth provider available to verify Google token"
|
||||
)
|
||||
# Store unverified token
|
||||
access_token = SimpleNamespace(
|
||||
token=token_str,
|
||||
client_id=os.getenv("GOOGLE_OAUTH_CLIENT_ID", "google"),
|
||||
scopes=[],
|
||||
session_id=f"google_oauth_{token_str[:8]}",
|
||||
expires_at=int(time.time()) + 3600, # Default to 1 hour
|
||||
sub="unknown",
|
||||
email="",
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token", access_token
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", self.auth_provider_type
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"token_type", "google_oauth"
|
||||
)
|
||||
|
||||
# Set the definitive authentication state for JWT tokens
|
||||
user_email = token_payload.get(
|
||||
"email", token_payload.get("username")
|
||||
)
|
||||
if user_email:
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "jwt_token"
|
||||
)
|
||||
authenticated_user = user_email
|
||||
auth_via = "jwt_token"
|
||||
|
||||
except jwt.DecodeError:
|
||||
logger.error("Failed to decode JWT token")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing JWT: {type(e).__name__}"
|
||||
)
|
||||
else:
|
||||
# Decode JWT to get user info
|
||||
try:
|
||||
token_payload = jwt.decode(
|
||||
token_str, options={"verify_signature": False}
|
||||
)
|
||||
logger.debug(
|
||||
f"JWT payload decoded: {list(token_payload.keys())}"
|
||||
)
|
||||
|
||||
# Create an AccessToken-like object
|
||||
access_token = SimpleNamespace(
|
||||
token=token_str,
|
||||
client_id=token_payload.get("client_id", "unknown"),
|
||||
scopes=token_payload.get("scope", "").split()
|
||||
if token_payload.get("scope")
|
||||
else [],
|
||||
session_id=token_payload.get(
|
||||
"sid",
|
||||
token_payload.get(
|
||||
"jti",
|
||||
token_payload.get("session_id", "unknown"),
|
||||
),
|
||||
),
|
||||
expires_at=token_payload.get("exp", 0),
|
||||
)
|
||||
|
||||
# Store in context state
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token", access_token
|
||||
)
|
||||
|
||||
# Store additional user info
|
||||
context.fastmcp_context.set_state(
|
||||
"user_id", token_payload.get("sub")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"username",
|
||||
token_payload.get(
|
||||
"username", token_payload.get("email")
|
||||
),
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"name", token_payload.get("name")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_time", token_payload.get("auth_time")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"issuer", token_payload.get("iss")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"audience", token_payload.get("aud")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"jti", token_payload.get("jti")
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", self.auth_provider_type
|
||||
)
|
||||
|
||||
# Set the definitive authentication state for JWT tokens
|
||||
user_email = token_payload.get(
|
||||
"email", token_payload.get("username")
|
||||
)
|
||||
if user_email:
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "jwt_token"
|
||||
)
|
||||
|
||||
logger.debug("JWT token processed successfully")
|
||||
|
||||
except jwt.DecodeError as e:
|
||||
logger.error(f"Failed to decode JWT: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing JWT: {e}")
|
||||
logger.debug("No Bearer token in Authorization header")
|
||||
else:
|
||||
logger.debug("No Bearer token in Authorization header")
|
||||
else:
|
||||
logger.debug(
|
||||
"No HTTP headers available (might be using stdio transport)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get HTTP request: {e}")
|
||||
logger.debug(
|
||||
"No HTTP headers available (might be using stdio transport)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get HTTP request: {e}")
|
||||
|
||||
# After trying HTTP headers, check for other authentication methods
|
||||
# This consolidates all authentication logic in the middleware
|
||||
if not context.fastmcp_context.get_state("authenticated_user_email"):
|
||||
if not authenticated_user:
|
||||
logger.debug(
|
||||
"No authentication found via bearer token, checking other methods"
|
||||
)
|
||||
@@ -323,11 +326,13 @@ class AuthInfoMiddleware(Middleware):
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", "oauth21_stdio"
|
||||
)
|
||||
authenticated_user = requested_user
|
||||
auth_via = "stdio_session"
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking stdio session: {e}")
|
||||
|
||||
# If no requested user was provided but exactly one session exists, assume it in stdio mode
|
||||
if not context.fastmcp_context.get_state("authenticated_user_email"):
|
||||
if not authenticated_user:
|
||||
try:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
|
||||
@@ -348,15 +353,17 @@ class AuthInfoMiddleware(Middleware):
|
||||
)
|
||||
context.fastmcp_context.set_state("user_email", single_user)
|
||||
context.fastmcp_context.set_state("username", single_user)
|
||||
authenticated_user = single_user
|
||||
auth_via = "stdio_single_session"
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Error determining stdio single-user session: {e}"
|
||||
)
|
||||
|
||||
# Check for MCP session binding
|
||||
if not context.fastmcp_context.get_state(
|
||||
"authenticated_user_email"
|
||||
) and hasattr(context.fastmcp_context, "session_id"):
|
||||
if not authenticated_user and hasattr(
|
||||
context.fastmcp_context, "session_id"
|
||||
):
|
||||
mcp_session_id = context.fastmcp_context.session_id
|
||||
if mcp_session_id:
|
||||
try:
|
||||
@@ -377,9 +384,18 @@ class AuthInfoMiddleware(Middleware):
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", "oauth21_session"
|
||||
)
|
||||
authenticated_user = bound_user
|
||||
auth_via = "mcp_session_binding"
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking MCP session binding: {e}")
|
||||
|
||||
# Single exit point with logging
|
||||
if authenticated_user:
|
||||
logger.info(f"✓ Authenticated via {auth_via}: {authenticated_user}")
|
||||
logger.debug(
|
||||
f"Context state after auth: authenticated_user_email={context.fastmcp_context.get_state('authenticated_user_email')}"
|
||||
)
|
||||
|
||||
async def on_call_tool(self, context: MiddlewareContext, call_next):
|
||||
"""Extract auth info from token and set in context state"""
|
||||
logger.debug("Processing tool call authentication")
|
||||
|
||||
@@ -3,18 +3,27 @@ External OAuth Provider for Google Workspace MCP
|
||||
|
||||
Extends FastMCP's GoogleProvider to support external OAuth flows where
|
||||
access tokens (ya29.*) are issued by external systems and need validation.
|
||||
|
||||
This provider acts as a Resource Server only - it validates tokens issued by
|
||||
Google's Authorization Server but does not issue tokens itself.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from starlette.routing import Route
|
||||
from fastmcp.server.auth.providers.google import GoogleProvider
|
||||
from fastmcp.server.auth import AccessToken
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from auth.oauth_types import WorkspaceAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Google's OAuth 2.0 Authorization Server
|
||||
GOOGLE_ISSUER_URL = "https://accounts.google.com"
|
||||
|
||||
|
||||
class ExternalOAuthProvider(GoogleProvider):
|
||||
"""
|
||||
@@ -22,14 +31,28 @@ class ExternalOAuthProvider(GoogleProvider):
|
||||
|
||||
This provider handles ya29.* access tokens by calling Google's userinfo API,
|
||||
while maintaining compatibility with standard JWT ID tokens.
|
||||
|
||||
Unlike the standard GoogleProvider, this acts as a Resource Server only:
|
||||
- Does NOT create /authorize, /token, /register endpoints
|
||||
- Only advertises Google's authorization server in metadata
|
||||
- Only validates tokens, does not issue them
|
||||
"""
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
resource_server_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize and store client credentials for token validation."""
|
||||
self._resource_server_url = resource_server_url
|
||||
super().__init__(client_id=client_id, client_secret=client_secret, **kwargs)
|
||||
# Store credentials as they're not exposed by parent class
|
||||
self._client_id = client_id
|
||||
self._client_secret = client_secret
|
||||
# Store as string - Pydantic validates it when passed to models
|
||||
self.resource_server_url = self._resource_server_url
|
||||
|
||||
async def verify_token(self, token: str) -> Optional[AccessToken]:
|
||||
"""
|
||||
@@ -68,12 +91,8 @@ class ExternalOAuthProvider(GoogleProvider):
|
||||
f"Validated external access token for: {user_info['email']}"
|
||||
)
|
||||
|
||||
# Create a mock AccessToken that the middleware expects
|
||||
# This matches the structure that FastMCP's AccessToken would have
|
||||
from types import SimpleNamespace
|
||||
|
||||
scope_list = list(getattr(self, "required_scopes", []) or [])
|
||||
access_token = SimpleNamespace(
|
||||
access_token = WorkspaceAccessToken(
|
||||
token=token,
|
||||
scopes=scope_list,
|
||||
expires_at=int(time.time())
|
||||
@@ -97,3 +116,40 @@ class ExternalOAuthProvider(GoogleProvider):
|
||||
|
||||
# For JWT tokens, use parent class implementation
|
||||
return await super().verify_token(token)
|
||||
|
||||
def get_routes(self, **kwargs) -> list[Route]:
|
||||
"""
|
||||
Get OAuth routes for external provider mode.
|
||||
|
||||
Returns only protected resource metadata routes that point to Google
|
||||
as the authorization server. Does not create authorization server routes
|
||||
(/authorize, /token, etc.) since tokens are issued by Google directly.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional arguments passed by FastMCP (e.g., mcp_path)
|
||||
|
||||
Returns:
|
||||
List of routes - only protected resource metadata
|
||||
"""
|
||||
from mcp.server.auth.routes import create_protected_resource_routes
|
||||
|
||||
if not self.resource_server_url:
|
||||
logger.warning(
|
||||
"ExternalOAuthProvider: resource_server_url not set, no routes created"
|
||||
)
|
||||
return []
|
||||
|
||||
# Create protected resource routes that point to Google as the authorization server
|
||||
# Pass strings directly - Pydantic validates them during model construction
|
||||
protected_routes = create_protected_resource_routes(
|
||||
resource_url=self.resource_server_url,
|
||||
authorization_servers=[GOOGLE_ISSUER_URL],
|
||||
scopes_supported=self.required_scopes,
|
||||
resource_name="Google Workspace MCP",
|
||||
resource_documentation=None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"ExternalOAuthProvider: Created protected resource routes pointing to {GOOGLE_ISSUER_URL}"
|
||||
)
|
||||
return protected_routes
|
||||
|
||||
@@ -8,6 +8,16 @@ improving code maintainability and type safety.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from fastmcp.server.auth import AccessToken
|
||||
|
||||
|
||||
class WorkspaceAccessToken(AccessToken):
|
||||
"""AccessToken extended with workspace-specific fields."""
|
||||
|
||||
session_id: Optional[str] = None
|
||||
sub: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuth21ServiceRequest:
|
||||
|
||||
@@ -14,7 +14,11 @@ from auth.oauth21_session_store import (
|
||||
get_oauth21_session_store,
|
||||
ensure_session_from_access_token,
|
||||
)
|
||||
from auth.oauth_config import is_oauth21_enabled, get_oauth_config
|
||||
from auth.oauth_config import (
|
||||
is_oauth21_enabled,
|
||||
get_oauth_config,
|
||||
is_external_oauth21_provider,
|
||||
)
|
||||
from core.context import set_fastmcp_session_id
|
||||
from auth.scopes import (
|
||||
GMAIL_READONLY_SCOPE,
|
||||
@@ -75,8 +79,8 @@ def _get_auth_context(
|
||||
if mcp_session_id:
|
||||
set_fastmcp_session_id(mcp_session_id)
|
||||
|
||||
logger.debug(
|
||||
f"[{tool_name}] Auth from middleware: {authenticated_user} via {auth_method}"
|
||||
logger.info(
|
||||
f"[{tool_name}] Auth from middleware: authenticated_user={authenticated_user}, auth_method={auth_method}, session_id={mcp_session_id}"
|
||||
)
|
||||
return authenticated_user, auth_method, mcp_session_id
|
||||
|
||||
@@ -104,6 +108,19 @@ def _detect_oauth_version(
|
||||
)
|
||||
return True
|
||||
|
||||
# If FastMCP protocol-level auth is enabled, a validated access token should
|
||||
# be available even if middleware state wasn't populated.
|
||||
try:
|
||||
if get_access_token() is not None:
|
||||
logger.info(
|
||||
f"[{tool_name}] OAuth 2.1 mode: Using OAuth 2.1 based on validated access token"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"[{tool_name}] Could not inspect access token for OAuth mode: {e}"
|
||||
)
|
||||
|
||||
# Only use version detection for unauthenticated requests
|
||||
config = get_oauth_config()
|
||||
request_params = {}
|
||||
@@ -486,6 +503,26 @@ def _handle_token_refresh_error(
|
||||
)
|
||||
|
||||
service_display_name = f"Google {service_name.title()}"
|
||||
if is_oauth21_enabled():
|
||||
if is_external_oauth21_provider():
|
||||
oauth21_step = (
|
||||
"Provide a valid OAuth 2.1 bearer token in the Authorization header"
|
||||
)
|
||||
else:
|
||||
oauth21_step = "Sign in through your MCP client's OAuth 2.1 flow"
|
||||
|
||||
return (
|
||||
f"**Authentication Required: Token Expired/Revoked for {service_display_name}**\n\n"
|
||||
f"Your Google authentication token for {user_email} has expired or been revoked. "
|
||||
f"This commonly happens when:\n"
|
||||
f"- The token has been unused for an extended period\n"
|
||||
f"- You've changed your Google account password\n"
|
||||
f"- You've revoked access to the application\n\n"
|
||||
f"**To resolve this, please:**\n"
|
||||
f"1. {oauth21_step}\n"
|
||||
f"2. Retry your original command\n\n"
|
||||
f"The application will automatically use the new credentials once authentication is complete."
|
||||
)
|
||||
|
||||
return (
|
||||
f"**Authentication Required: Token Expired/Revoked for {service_display_name}**\n\n"
|
||||
@@ -503,6 +540,16 @@ def _handle_token_refresh_error(
|
||||
else:
|
||||
# Handle other types of refresh errors
|
||||
logger.error(f"Unexpected refresh error for user {user_email}: {error}")
|
||||
if is_oauth21_enabled():
|
||||
if is_external_oauth21_provider():
|
||||
return (
|
||||
f"Authentication error occurred for {user_email}. "
|
||||
"Please provide a valid OAuth 2.1 bearer token and retry."
|
||||
)
|
||||
return (
|
||||
f"Authentication error occurred for {user_email}. "
|
||||
"Please sign in via your MCP client's OAuth 2.1 flow and retry."
|
||||
)
|
||||
return (
|
||||
f"Authentication error occurred for {user_email}. "
|
||||
f"Please try running `start_google_auth` with your email and the appropriate service name to reauthenticate."
|
||||
@@ -639,7 +686,7 @@ def require_google_service(
|
||||
error_message = _handle_token_refresh_error(
|
||||
e, actual_user_email, service_name
|
||||
)
|
||||
raise Exception(error_message)
|
||||
raise GoogleAuthenticationError(error_message)
|
||||
|
||||
# Set the wrapper's signature to the one without 'service'
|
||||
wrapper.__signature__ = wrapper_sig
|
||||
@@ -777,7 +824,7 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
||||
error_message = _handle_token_refresh_error(
|
||||
e, user_google_email, "Multiple Services"
|
||||
)
|
||||
raise Exception(error_message)
|
||||
raise GoogleAuthenticationError(error_message)
|
||||
|
||||
# Set the wrapper's signature
|
||||
wrapper.__signature__ = wrapper_sig
|
||||
|
||||
Reference in New Issue
Block a user