fix stdio
This commit is contained in:
@@ -6,7 +6,6 @@ import logging
|
||||
import os
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict
|
||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
@@ -30,9 +29,9 @@ class AuthInfoMiddleware(Middleware):
|
||||
logger.warning("No fastmcp_context available")
|
||||
return
|
||||
|
||||
# Return early if token is already in state
|
||||
if context.fastmcp_context.get_state("access_token"):
|
||||
logger.info("Access token already in state.")
|
||||
# Return early if authentication state is already set
|
||||
if context.fastmcp_context.get_state("authenticated_user_email"):
|
||||
logger.info("Authentication state already set.")
|
||||
return
|
||||
|
||||
# Try to get the HTTP request to extract Authorization header
|
||||
@@ -86,16 +85,20 @@ class AuthInfoMiddleware(Middleware):
|
||||
email=user_email
|
||||
)
|
||||
|
||||
# Store in context state
|
||||
# Store in context state - this is the authoritative authentication state
|
||||
context.fastmcp_context.set_state("access_token", access_token)
|
||||
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
|
||||
context.fastmcp_context.set_state("token_type", "google_oauth")
|
||||
context.fastmcp_context.set_state("user_email", user_email)
|
||||
context.fastmcp_context.set_state("username", user_email)
|
||||
# Set the definitive authentication state
|
||||
context.fastmcp_context.set_state("authenticated_user_email", user_email)
|
||||
context.fastmcp_context.set_state("authenticated_via", "bearer_token")
|
||||
|
||||
logger.info(f"Stored verified Google OAuth token for user: {user_email}")
|
||||
else:
|
||||
logger.error("Failed to verify Google OAuth token")
|
||||
# Don't set authenticated_user_email if verification failed
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying Google OAuth token: {e}")
|
||||
# Still store the unverified token - service decorator will handle verification
|
||||
@@ -158,6 +161,12 @@ class AuthInfoMiddleware(Middleware):
|
||||
context.fastmcp_context.set_state("jti", token_payload.get("jti"))
|
||||
context.fastmcp_context.set_state("auth_provider_type", self.auth_provider_type)
|
||||
|
||||
# Set the definitive authentication state for JWT tokens
|
||||
user_email = token_payload.get("email", token_payload.get("username"))
|
||||
if user_email:
|
||||
context.fastmcp_context.set_state("authenticated_user_email", user_email)
|
||||
context.fastmcp_context.set_state("authenticated_via", "jwt_token")
|
||||
|
||||
logger.info("Successfully extracted and stored auth info from HTTP request")
|
||||
|
||||
except jwt.DecodeError as e:
|
||||
@@ -170,6 +179,61 @@ class AuthInfoMiddleware(Middleware):
|
||||
logger.debug("No HTTP headers available (might be using stdio transport)")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get HTTP request: {e}")
|
||||
|
||||
# After trying HTTP headers, check for other authentication methods
|
||||
# This consolidates all authentication logic in the middleware
|
||||
if not context.fastmcp_context.get_state("authenticated_user_email"):
|
||||
logger.debug("No authentication found via bearer token, checking other methods")
|
||||
|
||||
# Check transport mode
|
||||
from core.config import get_transport_mode
|
||||
transport_mode = get_transport_mode()
|
||||
|
||||
if transport_mode == "stdio":
|
||||
# In stdio mode, check if there's a session with credentials
|
||||
# This is ONLY safe in stdio mode because it's single-user
|
||||
logger.debug("Checking for stdio mode authentication")
|
||||
|
||||
# Get the requested user from the context if available
|
||||
requested_user = None
|
||||
if hasattr(context, 'request') and hasattr(context.request, 'params'):
|
||||
requested_user = context.request.params.get('user_google_email')
|
||||
elif hasattr(context, 'arguments'):
|
||||
# FastMCP may store arguments differently
|
||||
requested_user = context.arguments.get('user_google_email')
|
||||
|
||||
if requested_user:
|
||||
try:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
# Check if user has a recent session
|
||||
if store.has_session(requested_user):
|
||||
logger.info(f"User {requested_user} has recent auth session in stdio mode")
|
||||
# In stdio mode, we can trust the user has authenticated recently
|
||||
context.fastmcp_context.set_state("authenticated_user_email", requested_user)
|
||||
context.fastmcp_context.set_state("authenticated_via", "stdio_session")
|
||||
context.fastmcp_context.set_state("auth_provider_type", "oauth21_stdio")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking stdio session: {e}")
|
||||
|
||||
# Check for MCP session binding
|
||||
if not context.fastmcp_context.get_state("authenticated_user_email") and hasattr(context.fastmcp_context, 'session_id'):
|
||||
mcp_session_id = context.fastmcp_context.session_id
|
||||
if mcp_session_id:
|
||||
try:
|
||||
from auth.oauth21_session_store import get_oauth21_session_store
|
||||
store = get_oauth21_session_store()
|
||||
|
||||
# Check if this MCP session is bound to a user
|
||||
bound_user = store.get_user_by_mcp_session(mcp_session_id)
|
||||
if bound_user:
|
||||
logger.info(f"MCP session {mcp_session_id} is bound to user {bound_user}")
|
||||
context.fastmcp_context.set_state("authenticated_user_email", bound_user)
|
||||
context.fastmcp_context.set_state("authenticated_via", "mcp_session_binding")
|
||||
context.fastmcp_context.set_state("auth_provider_type", "oauth21_session")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking MCP session binding: {e}")
|
||||
|
||||
async def on_call_tool(self, context: MiddlewareContext, call_next):
|
||||
"""Extract auth info from token and set in context state"""
|
||||
|
||||
Reference in New Issue
Block a user