simplify flow
This commit is contained in:
@@ -32,6 +32,9 @@ class AuthInfoMiddleware(Middleware):
|
||||
logger.warning("No fastmcp_context available")
|
||||
return
|
||||
|
||||
authenticated_user = None
|
||||
auth_via = None
|
||||
|
||||
# First check if FastMCP has already validated an access token
|
||||
try:
|
||||
access_token = get_access_token()
|
||||
@@ -54,7 +57,8 @@ class AuthInfoMiddleware(Middleware):
|
||||
"authenticated_via", "fastmcp_oauth"
|
||||
)
|
||||
context.fastmcp_context.set_state("access_token", access_token)
|
||||
return
|
||||
authenticated_user = user_email
|
||||
auth_via = "fastmcp_oauth"
|
||||
else:
|
||||
logger.warning(
|
||||
f"FastMCP access_token found but no email. Attributes: {dir(access_token)}"
|
||||
@@ -63,6 +67,7 @@ class AuthInfoMiddleware(Middleware):
|
||||
logger.debug(f"Could not get FastMCP access_token: {e}")
|
||||
|
||||
# Try to get the HTTP request to extract Authorization header
|
||||
if not authenticated_user:
|
||||
try:
|
||||
# Use the new FastMCP method to get HTTP headers
|
||||
headers = get_http_headers()
|
||||
@@ -162,14 +167,8 @@ class AuthInfoMiddleware(Middleware):
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "bearer_token"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"✓ Authenticated via Google OAuth: {user_email}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Context state after auth: authenticated_user_email={context.fastmcp_context.get_state('authenticated_user_email')}"
|
||||
)
|
||||
return
|
||||
authenticated_user = user_email
|
||||
auth_via = "bearer_token"
|
||||
else:
|
||||
logger.error("Failed to verify Google OAuth token")
|
||||
except Exception as e:
|
||||
@@ -252,11 +251,8 @@ class AuthInfoMiddleware(Middleware):
|
||||
context.fastmcp_context.set_state(
|
||||
"authenticated_via", "jwt_token"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"✓ JWT token processed successfully for user: {user_email}"
|
||||
)
|
||||
return
|
||||
authenticated_user = user_email
|
||||
auth_via = "jwt_token"
|
||||
|
||||
except jwt.DecodeError as e:
|
||||
logger.error(f"Failed to decode JWT: {e}", exc_info=True)
|
||||
@@ -273,7 +269,7 @@ class AuthInfoMiddleware(Middleware):
|
||||
|
||||
# After trying HTTP headers, check for other authentication methods
|
||||
# This consolidates all authentication logic in the middleware
|
||||
if not context.fastmcp_context.get_state("authenticated_user_email"):
|
||||
if not authenticated_user:
|
||||
logger.debug(
|
||||
"No authentication found via bearer token, checking other methods"
|
||||
)
|
||||
@@ -317,11 +313,13 @@ class AuthInfoMiddleware(Middleware):
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", "oauth21_stdio"
|
||||
)
|
||||
authenticated_user = requested_user
|
||||
auth_via = "stdio_session"
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking stdio session: {e}")
|
||||
|
||||
# If no requested user was provided but exactly one session exists, assume it in stdio mode
|
||||
if not context.fastmcp_context.get_state("authenticated_user_email"):
|
||||
if not authenticated_user:
|
||||
try:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
|
||||
@@ -342,15 +340,15 @@ class AuthInfoMiddleware(Middleware):
|
||||
)
|
||||
context.fastmcp_context.set_state("user_email", single_user)
|
||||
context.fastmcp_context.set_state("username", single_user)
|
||||
authenticated_user = single_user
|
||||
auth_via = "stdio_single_session"
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Error determining stdio single-user session: {e}"
|
||||
)
|
||||
|
||||
# Check for MCP session binding
|
||||
if not context.fastmcp_context.get_state(
|
||||
"authenticated_user_email"
|
||||
) and hasattr(context.fastmcp_context, "session_id"):
|
||||
if not authenticated_user and hasattr(context.fastmcp_context, "session_id"):
|
||||
mcp_session_id = context.fastmcp_context.session_id
|
||||
if mcp_session_id:
|
||||
try:
|
||||
@@ -371,9 +369,18 @@ class AuthInfoMiddleware(Middleware):
|
||||
context.fastmcp_context.set_state(
|
||||
"auth_provider_type", "oauth21_session"
|
||||
)
|
||||
authenticated_user = bound_user
|
||||
auth_via = "mcp_session_binding"
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking MCP session binding: {e}")
|
||||
|
||||
# Single exit point with logging
|
||||
if authenticated_user:
|
||||
logger.info(f"✓ Authenticated via {auth_via}: {authenticated_user}")
|
||||
logger.debug(
|
||||
f"Context state after auth: authenticated_user_email={context.fastmcp_context.get_state('authenticated_user_email')}"
|
||||
)
|
||||
|
||||
async def on_call_tool(self, context: MiddlewareContext, call_next):
|
||||
"""Extract auth info from token and set in context state"""
|
||||
logger.debug("Processing tool call authentication")
|
||||
|
||||
Reference in New Issue
Block a user