v3 auth middleware fix
This commit is contained in:
@@ -4,6 +4,7 @@ Authentication middleware to populate context state with user information
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
@@ -50,13 +51,15 @@ class AuthInfoMiddleware(Middleware):
|
||||
logger.info(
|
||||
f"✓ Using FastMCP validated token for user: {user_email}"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"authenticated_via", "fastmcp_oauth"
|
||||
)
|
||||
context.fastmcp_context.set_state("access_token", access_token)
|
||||
await context.fastmcp_context.set_state(
|
||||
"access_token", access_token, serializable=False
|
||||
)
|
||||
authenticated_user = user_email
|
||||
auth_via = "fastmcp_oauth"
|
||||
else:
|
||||
@@ -146,8 +149,10 @@ class AuthInfoMiddleware(Middleware):
|
||||
)
|
||||
|
||||
# Store in context state - this is the authoritative authentication state
|
||||
context.fastmcp_context.set_state(
|
||||
"access_token", access_token
|
||||
await context.fastmcp_context.set_state(
|
||||
"access_token",
|
||||
access_token,
|
||||
serializable=False,
|
||||
)
|
||||
mcp_session_id = getattr(
|
||||
context.fastmcp_context, "session_id", None
|
||||
@@ -157,24 +162,24 @@ class AuthInfoMiddleware(Middleware):
|
||||
user_email,
|
||||
mcp_session_id,
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"auth_provider_type",
|
||||
self.auth_provider_type,
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"token_type", "google_oauth"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"username", user_email
|
||||
)
|
||||
# Set the definitive authentication state
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", user_email
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"authenticated_via", "bearer_token"
|
||||
)
|
||||
authenticated_user = user_email
|
||||
@@ -244,13 +249,13 @@ class AuthInfoMiddleware(Middleware):
|
||||
f"Using recent stdio session for {requested_user}"
|
||||
)
|
||||
# In stdio mode, we can trust the user has authenticated recently
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", requested_user
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"authenticated_via", "stdio_session"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"auth_provider_type", "oauth21_stdio"
|
||||
)
|
||||
authenticated_user = requested_user
|
||||
@@ -269,17 +274,21 @@ class AuthInfoMiddleware(Middleware):
|
||||
logger.debug(
|
||||
f"Defaulting to single stdio OAuth session for {single_user}"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", single_user
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"authenticated_via", "stdio_single_session"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"auth_provider_type", "oauth21_stdio"
|
||||
)
|
||||
context.fastmcp_context.set_state("user_email", single_user)
|
||||
context.fastmcp_context.set_state("username", single_user)
|
||||
await context.fastmcp_context.set_state(
|
||||
"user_email", single_user
|
||||
)
|
||||
await context.fastmcp_context.set_state(
|
||||
"username", single_user
|
||||
)
|
||||
authenticated_user = single_user
|
||||
auth_via = "stdio_single_session"
|
||||
except Exception as e:
|
||||
@@ -302,13 +311,13 @@ class AuthInfoMiddleware(Middleware):
|
||||
bound_user = store.get_user_by_mcp_session(mcp_session_id)
|
||||
if bound_user:
|
||||
logger.debug(f"MCP session bound to {bound_user}")
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"authenticated_user_email", bound_user
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"authenticated_via", "mcp_session_binding"
|
||||
)
|
||||
context.fastmcp_context.set_state(
|
||||
await context.fastmcp_context.set_state(
|
||||
"auth_provider_type", "oauth21_session"
|
||||
)
|
||||
authenticated_user = bound_user
|
||||
@@ -319,8 +328,11 @@ class AuthInfoMiddleware(Middleware):
|
||||
# Single exit point with logging
|
||||
if authenticated_user:
|
||||
logger.info(f"✓ Authenticated via {auth_via}: {authenticated_user}")
|
||||
auth_email = await context.fastmcp_context.get_state(
|
||||
"authenticated_user_email"
|
||||
)
|
||||
logger.debug(
|
||||
f"Context state after auth: authenticated_user_email={context.fastmcp_context.get_state('authenticated_user_email')}"
|
||||
f"Context state after auth: authenticated_user_email={auth_email}"
|
||||
)
|
||||
|
||||
async def on_call_tool(self, context: MiddlewareContext, call_next):
|
||||
|
||||
@@ -59,7 +59,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Authentication helper functions
|
||||
def _get_auth_context(
|
||||
async def _get_auth_context(
|
||||
tool_name: str,
|
||||
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""
|
||||
@@ -73,8 +73,8 @@ def _get_auth_context(
|
||||
if not ctx:
|
||||
return None, None, None
|
||||
|
||||
authenticated_user = ctx.get_state("authenticated_user_email")
|
||||
auth_method = ctx.get_state("authenticated_via")
|
||||
authenticated_user = await ctx.get_state("authenticated_user_email")
|
||||
auth_method = await ctx.get_state("authenticated_via")
|
||||
mcp_session_id = ctx.session_id if hasattr(ctx, "session_id") else None
|
||||
|
||||
if mcp_session_id:
|
||||
@@ -604,7 +604,7 @@ def require_google_service(
|
||||
# which does not include 'service'.
|
||||
|
||||
# Get authentication context early to determine OAuth mode
|
||||
authenticated_user, auth_method, mcp_session_id = _get_auth_context(
|
||||
authenticated_user, auth_method, mcp_session_id = await _get_auth_context(
|
||||
func.__name__
|
||||
)
|
||||
|
||||
@@ -751,7 +751,7 @@ def require_multiple_services(service_configs: List[Dict[str, Any]]):
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Get authentication context early
|
||||
tool_name = func.__name__
|
||||
authenticated_user, _, mcp_session_id = _get_auth_context(tool_name)
|
||||
authenticated_user, _, mcp_session_id = await _get_auth_context(tool_name)
|
||||
|
||||
# Extract user_google_email based on OAuth mode
|
||||
if is_oauth21_enabled():
|
||||
|
||||
Reference in New Issue
Block a user