v3 auth middleware fix

This commit is contained in:
Taylor Wilsdon
2026-02-13 10:20:39 -05:00
parent a3107e900b
commit 0075e8338f
7 changed files with 1086 additions and 1228 deletions

View File

@@ -4,6 +4,7 @@ Authentication middleware to populate context state with user information
import logging
import time
from fastmcp.server.middleware import Middleware, MiddlewareContext
from fastmcp.server.dependencies import get_access_token
from fastmcp.server.dependencies import get_http_headers
@@ -50,13 +51,15 @@ class AuthInfoMiddleware(Middleware):
logger.info(
f"✓ Using FastMCP validated token for user: {user_email}"
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"authenticated_user_email", user_email
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"authenticated_via", "fastmcp_oauth"
)
context.fastmcp_context.set_state("access_token", access_token)
await context.fastmcp_context.set_state(
"access_token", access_token, serializable=False
)
authenticated_user = user_email
auth_via = "fastmcp_oauth"
else:
@@ -146,8 +149,10 @@ class AuthInfoMiddleware(Middleware):
)
# Store in context state - this is the authoritative authentication state
context.fastmcp_context.set_state(
"access_token", access_token
await context.fastmcp_context.set_state(
"access_token",
access_token,
serializable=False,
)
mcp_session_id = getattr(
context.fastmcp_context, "session_id", None
@@ -157,24 +162,24 @@ class AuthInfoMiddleware(Middleware):
user_email,
mcp_session_id,
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"auth_provider_type",
self.auth_provider_type,
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"token_type", "google_oauth"
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"user_email", user_email
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"username", user_email
)
# Set the definitive authentication state
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"authenticated_user_email", user_email
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"authenticated_via", "bearer_token"
)
authenticated_user = user_email
@@ -244,13 +249,13 @@ class AuthInfoMiddleware(Middleware):
f"Using recent stdio session for {requested_user}"
)
# In stdio mode, we can trust the user has authenticated recently
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"authenticated_user_email", requested_user
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"authenticated_via", "stdio_session"
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"auth_provider_type", "oauth21_stdio"
)
authenticated_user = requested_user
@@ -269,17 +274,21 @@ class AuthInfoMiddleware(Middleware):
logger.debug(
f"Defaulting to single stdio OAuth session for {single_user}"
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"authenticated_user_email", single_user
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"authenticated_via", "stdio_single_session"
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"auth_provider_type", "oauth21_stdio"
)
context.fastmcp_context.set_state("user_email", single_user)
context.fastmcp_context.set_state("username", single_user)
await context.fastmcp_context.set_state(
"user_email", single_user
)
await context.fastmcp_context.set_state(
"username", single_user
)
authenticated_user = single_user
auth_via = "stdio_single_session"
except Exception as e:
@@ -302,13 +311,13 @@ class AuthInfoMiddleware(Middleware):
bound_user = store.get_user_by_mcp_session(mcp_session_id)
if bound_user:
logger.debug(f"MCP session bound to {bound_user}")
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"authenticated_user_email", bound_user
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"authenticated_via", "mcp_session_binding"
)
context.fastmcp_context.set_state(
await context.fastmcp_context.set_state(
"auth_provider_type", "oauth21_session"
)
authenticated_user = bound_user
@@ -319,8 +328,11 @@ class AuthInfoMiddleware(Middleware):
# Single exit point with logging
if authenticated_user:
logger.info(f"✓ Authenticated via {auth_via}: {authenticated_user}")
auth_email = await context.fastmcp_context.get_state(
"authenticated_user_email"
)
logger.debug(
f"Context state after auth: authenticated_user_email={context.fastmcp_context.get_state('authenticated_user_email')}"
f"Context state after auth: authenticated_user_email={auth_email}"
)
async def on_call_tool(self, context: MiddlewareContext, call_next):

View File

@@ -59,7 +59,7 @@ logger = logging.getLogger(__name__)
# Authentication helper functions
def _get_auth_context(
async def _get_auth_context(
tool_name: str,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""
@@ -73,8 +73,8 @@ def _get_auth_context(
if not ctx:
return None, None, None
authenticated_user = ctx.get_state("authenticated_user_email")
auth_method = ctx.get_state("authenticated_via")
authenticated_user = await ctx.get_state("authenticated_user_email")
auth_method = await ctx.get_state("authenticated_via")
mcp_session_id = ctx.session_id if hasattr(ctx, "session_id") else None
if mcp_session_id:
@@ -604,7 +604,7 @@ def require_google_service(
# which does not include 'service'.
# Get authentication context early to determine OAuth mode
authenticated_user, auth_method, mcp_session_id = _get_auth_context(
authenticated_user, auth_method, mcp_session_id = await _get_auth_context(
func.__name__
)
@@ -751,7 +751,7 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
async def wrapper(*args, **kwargs):
# Get authentication context early
tool_name = func.__name__
authenticated_user, _, mcp_session_id = _get_auth_context(tool_name)
authenticated_user, _, mcp_session_id = await _get_auth_context(tool_name)
# Extract user_google_email based on OAuth mode
if is_oauth21_enabled():