fix stdio

This commit is contained in:
Taylor Wilsdon
2025-08-05 11:22:22 -04:00
parent ee792b65d8
commit e03b10c024
4 changed files with 190 additions and 324 deletions

View File

@@ -6,7 +6,6 @@ import logging
import os
import time
from types import SimpleNamespace
from typing import Any, Dict
from fastmcp.server.middleware import Middleware, MiddlewareContext
from fastmcp.server.dependencies import get_http_headers
@@ -30,9 +29,9 @@ class AuthInfoMiddleware(Middleware):
logger.warning("No fastmcp_context available")
return
# Return early if token is already in state
if context.fastmcp_context.get_state("access_token"):
logger.info("Access token already in state.")
# Return early if authentication state is already set
if context.fastmcp_context.get_state("authenticated_user_email"):
logger.info("Authentication state already set.")
return
# Try to get the HTTP request to extract Authorization header
@@ -86,16 +85,20 @@ class AuthInfoMiddleware(Middleware):
email=user_email
)
# Store in context state
# Store in context state - this is the authoritative authentication state
context.fastmcp_context.set_state("access_token", access_token)
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
context.fastmcp_context.set_state("token_type", "google_oauth")
context.fastmcp_context.set_state("user_email", user_email)
context.fastmcp_context.set_state("username", user_email)
# Set the definitive authentication state
context.fastmcp_context.set_state("authenticated_user_email", user_email)
context.fastmcp_context.set_state("authenticated_via", "bearer_token")
logger.info(f"Stored verified Google OAuth token for user: {user_email}")
else:
logger.error("Failed to verify Google OAuth token")
# Don't set authenticated_user_email if verification failed
except Exception as e:
logger.error(f"Error verifying Google OAuth token: {e}")
# Still store the unverified token - service decorator will handle verification
@@ -158,6 +161,12 @@ class AuthInfoMiddleware(Middleware):
context.fastmcp_context.set_state("jti", token_payload.get("jti"))
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
# Set the definitive authentication state for JWT tokens
user_email = token_payload.get("email", token_payload.get("username"))
if user_email:
context.fastmcp_context.set_state("authenticated_user_email", user_email)
context.fastmcp_context.set_state("authenticated_via", "jwt_token")
logger.info("Successfully extracted and stored auth info from HTTP request")
except jwt.DecodeError as e:
@@ -170,6 +179,61 @@ class AuthInfoMiddleware(Middleware):
logger.debug("No HTTP headers available (might be using stdio transport)")
except Exception as e:
logger.debug(f"Could not get HTTP request: {e}")
# After trying HTTP headers, check for other authentication methods
# This consolidates all authentication logic in the middleware
if not context.fastmcp_context.get_state("authenticated_user_email"):
logger.debug("No authentication found via bearer token, checking other methods")
# Check transport mode
from core.config import get_transport_mode
transport_mode = get_transport_mode()
if transport_mode == "stdio":
# In stdio mode, check if there's a session with credentials
# This is ONLY safe in stdio mode because it's single-user
logger.debug("Checking for stdio mode authentication")
# Get the requested user from the context if available
requested_user = None
if hasattr(context, 'request') and hasattr(context.request, 'params'):
requested_user = context.request.params.get('user_google_email')
elif hasattr(context, 'arguments'):
# FastMCP may store arguments differently
requested_user = context.arguments.get('user_google_email')
if requested_user:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
# Check if user has a recent session
if store.has_session(requested_user):
logger.info(f"User {requested_user} has recent auth session in stdio mode")
# In stdio mode, we can trust the user has authenticated recently
context.fastmcp_context.set_state("authenticated_user_email", requested_user)
context.fastmcp_context.set_state("authenticated_via", "stdio_session")
context.fastmcp_context.set_state("auth_provider_type", "oauth21_stdio")
except Exception as e:
logger.debug(f"Error checking stdio session: {e}")
# Check for MCP session binding
if not context.fastmcp_context.get_state("authenticated_user_email") and hasattr(context.fastmcp_context, 'session_id'):
mcp_session_id = context.fastmcp_context.session_id
if mcp_session_id:
try:
from auth.oauth21_session_store import get_oauth21_session_store
store = get_oauth21_session_store()
# Check if this MCP session is bound to a user
bound_user = store.get_user_by_mcp_session(mcp_session_id)
if bound_user:
logger.info(f"MCP session {mcp_session_id} is bound to user {bound_user}")
context.fastmcp_context.set_state("authenticated_user_email", bound_user)
context.fastmcp_context.set_state("authenticated_via", "mcp_session_binding")
context.fastmcp_context.set_state("auth_provider_type", "oauth21_session")
except Exception as e:
logger.debug(f"Error checking MCP session binding: {e}")
async def on_call_tool(self, context: MiddlewareContext, call_next):
"""Extract auth info from token and set in context state"""