refac fastmcp custom implementation to standard

This commit is contained in:
Taylor Wilsdon
2025-05-30 11:09:56 -04:00
parent b00fe41de8
commit 395d02494b
7 changed files with 261 additions and 730 deletions

View File

@@ -1,20 +1,14 @@
import logging
import os
import threading
from typing import Dict, Any, Optional
from fastapi import Request, Header
from fastapi import Header
from fastapi.responses import HTMLResponse
from mcp import types
from mcp.server.fastmcp import FastMCP
from fastmcp.server.dependencies import get_http_request
from starlette.requests import Request
from starlette.applications import Starlette
# Import our custom StreamableHTTP session manager
from core.streamable_http import SessionAwareStreamableHTTPManager, create_starlette_app
from google.auth.exceptions import RefreshError
from auth.google_auth import handle_auth_callback, start_auth_flow, CONFIG_CLIENT_SECRETS_PATH
@@ -50,86 +44,9 @@ WORKSPACE_MCP_BASE_URI = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
# Basic MCP server instance
server = FastMCP(
name="google_workspace",
server_url=f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/gworkspace", # Add absolute URL for Gemini native function calling
host="0.0.0.0", # Listen on all interfaces
port=WORKSPACE_MCP_PORT, # Default port for HTTP server
stateless_http=False # Enable stateful sessions (default)
server_url=f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/mcp", # Add absolute URL for Gemini native function calling
)
# Container for session manager
class SessionManagerContainer:
"""
Thread-safe container for the session manager instance.
This encapsulates the session manager to improve testability and thread safety
by avoiding direct global variable access.
"""
def __init__(self):
self._session_manager = None
self._lock = threading.Lock()
def set_session_manager(self, manager) -> None:
"""Set the session manager instance in a thread-safe manner."""
with self._lock:
self._session_manager = manager
def get_session_manager(self):
"""Get the current session manager instance in a thread-safe manner."""
with self._lock:
return self._session_manager
session_manager_container = SessionManagerContainer()
def get_session_manager():
"""
Get the current session manager instance.
Returns:
The session manager instance if initialized, None otherwise
"""
return session_manager_container.get_session_manager()
def log_all_active_sessions():
"""
Log information about all active sessions for debugging purposes.
"""
session_manager = session_manager_container.get_session_manager()
if session_manager is None:
logger.debug("Cannot log sessions: session_manager is not initialized")
return
active_sessions = session_manager.get_active_sessions()
session_count = len(active_sessions)
logger.debug(f"Active sessions: {session_count}")
for session_id, info in active_sessions.items():
logger.debug(f"Session ID: {session_id}, Created: {info.get('created_at')}, Last Active: {info.get('last_active')}")
def create_application(base_path="/gworkspace") -> Starlette:
"""
Create a Starlette application with the MCP server and session manager.
Args:
base_path: The base path to mount the MCP server at
Returns:
A Starlette application
"""
logger.info(f"Creating Starlette application with MCP server mounted at {base_path}")
app, manager = create_starlette_app(server._mcp_server, base_path)
session_manager_container.set_session_manager(manager)
# Add the OAuth callback route to the Starlette application
from starlette.routing import Route
# Add the OAuth callback route
app.routes.append(
Route("/oauth2callback", endpoint=oauth2_callback, methods=["GET"])
)
return app
# Configure OAuth redirect URI to use the MCP server's base uri and port
OAUTH_REDIRECT_URI = f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/oauth2callback"
@@ -292,54 +209,6 @@ async def start_google_auth(
redirect_uri=OAUTH_REDIRECT_URI
)
@server.tool()
async def get_active_sessions() -> Dict[str, Any]:
"""
Retrieve information about all active MCP sessions.
LLM Guidance:
- Use this tool to get information about currently active sessions
- This is useful for debugging or when you need to understand the active user sessions
Returns:
A dictionary mapping session IDs to session information
"""
session_manager = session_manager_container.get_session_manager()
if session_manager is None:
logger.error("get_active_sessions called but session_manager is not initialized")
return {"error": "Session manager not initialized"}
active_sessions = session_manager.get_active_sessions()
return active_sessions
@server.tool()
async def get_session_info(session_id: str) -> Optional[Dict[str, Any]]:
"""
Retrieve information about a specific MCP session.
LLM Guidance:
- Use this tool when you need details about a specific session
- Provide the session_id parameter to identify which session to retrieve
Args:
session_id: The ID of the session to retrieve
Returns:
Session information if found, None otherwise
"""
session_manager = session_manager_container.get_session_manager()
if session_manager is None:
logger.error(f"get_session_info({session_id}) called but session_manager is not initialized")
return {"error": "Session manager not initialized"}
session_info = session_manager.get_session(session_id)
if session_info is None:
logger.debug(f"Session {session_id} not found")
return {"error": f"Session {session_id} not found"}
return session_info
@server.tool()
async def debug_current_session(
mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id")
@@ -349,46 +218,23 @@ async def debug_current_session(
LLM Guidance:
- Use this tool to verify that session tracking is working correctly
- This tool will return information about the current session based on the Mcp-Session-Id header
- This tool will return the current session ID from the Mcp-Session-Id header
Args:
mcp_session_id: The MCP session ID header (automatically injected)
Returns:
Information about the current session and all active sessions
Information about the current session
"""
session_manager = session_manager_container.get_session_manager()
# Get the HTTP request to access headers
req: Request = get_http_request()
headers = dict(req.headers)
# Log all active sessions for debugging
log_all_active_sessions()
result = {
"current_session": {
"session_id": mcp_session_id,
"headers": headers
},
"session_info": None,
"active_sessions_count": 0
"message": "Session tracking is handled natively by FastMCP"
}
}
# Get info for the current session if available
if session_manager is not None and mcp_session_id:
session_info = session_manager.get_session(mcp_session_id)
result["session_info"] = session_info
# Count active sessions
active_sessions = session_manager.get_active_sessions()
result["active_sessions_count"] = len(active_sessions)
result["active_session_ids"] = list(active_sessions.keys())
else:
result["error"] = "Unable to retrieve session information"
if session_manager is None:
result["error_details"] = "Session manager not initialized"
elif not mcp_session_id:
result["error_details"] = "No session ID provided in request"
if not mcp_session_id:
result["error"] = "No session ID provided in request"
result["error_details"] = "The Mcp-Session-Id header was not found"
return result

View File

@@ -1,262 +0,0 @@
import contextlib
import logging
import os
from collections.abc import AsyncIterator
from typing import Optional, Dict, Any, Callable
import anyio
from mcp.server.lowlevel import Server
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.routing import Mount
from starlette.types import Receive, Scope, Send, ASGIApp
# Global variable to store the current session ID for the current request
# This will be used to pass the session ID to the FastMCP tools
CURRENT_SESSION_ID = None
logger = logging.getLogger(__name__)
class SessionAwareStreamableHTTPManager:
"""
A wrapper around StreamableHTTPSessionManager that provides access to session information.
This class enables retrieving active session data which can be useful for tools that need
to know about current sessions.
"""
def __init__(self, app: Server, stateless: bool = False, event_store: Optional[Any] = None):
"""
Initialize the session manager wrapper.
Args:
app: The MCP Server instance
stateless: Whether to use stateless mode (default: False)
event_store: Optional event store for storing session events
"""
self.session_manager = StreamableHTTPSessionManager(
app=app,
event_store=event_store,
stateless=stateless
)
self._sessions = {}
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
Handle an incoming request by delegating to the underlying session manager.
Args:
scope: The ASGI scope
receive: The ASGI receive function
send: The ASGI send function
"""
global CURRENT_SESSION_ID
# Check for session ID in headers
headers = dict(scope.get("headers", []))
session_id = None
for key, value in headers.items():
if key.lower() == b'mcp-session-id':
session_id = value.decode('utf-8')
break
# Extract session ID from StreamableHTTP if not in headers
if not session_id and hasattr(self.session_manager, '_get_session_id_from_scope'):
try:
# Try to get session ID directly from StreamableHTTP manager
session_id = self.session_manager._get_session_id_from_scope(scope)
except Exception:
pass
# Set the global session ID for this request
if session_id:
CURRENT_SESSION_ID = session_id
# Inject the session ID into the request headers
# This allows FastMCP to access it via the Header mechanism
new_headers = []
has_session_header = False
for k, v in scope.get("headers", []):
if k.lower() == b'mcp-session-id':
new_headers.append((k, session_id.encode('utf-8')))
has_session_header = True
else:
new_headers.append((k, v))
# Add the header if it doesn't exist
if not has_session_header:
new_headers.append((b'mcp-session-id', session_id.encode('utf-8')))
# Replace headers in scope
scope["headers"] = new_headers
else:
CURRENT_SESSION_ID = None
# Create a wrapper for the send function to capture the session ID
# from the response if it's not in the request
original_send = send
async def wrapped_send(message):
nonlocal session_id
global CURRENT_SESSION_ID
# If this is a response, check for session ID in headers
if message.get("type") == "http.response.start" and not session_id:
headers = message.get("headers", [])
for k, v in headers:
if k.lower() == b'mcp-session-id':
new_session_id = v.decode('utf-8')
CURRENT_SESSION_ID = new_session_id
break
await original_send(message)
# Process the request with the wrapped send function
await self.session_manager.handle_request(scope, receive, wrapped_send)
# Clear the global session ID after the request is done
CURRENT_SESSION_ID = None
@contextlib.asynccontextmanager
async def run(self) -> AsyncIterator[None]:
"""
Context manager for running the session manager.
Yields:
None
"""
async with self.session_manager.run():
logger.debug("SessionAwareStreamableHTTPManager started")
try:
yield
finally:
logger.debug("SessionAwareStreamableHTTPManager shutting down")
def get_active_sessions(self) -> Dict[str, Any]:
"""
Get information about all active sessions.
Returns:
A dictionary mapping session IDs to session information
"""
# Access the internal sessions dictionary from the session manager
if hasattr(self.session_manager, '_sessions'):
return {
session_id: {
"created_at": session.created_at,
"last_active": session.last_active,
"client_id": session.client_id if hasattr(session, 'client_id') else None,
}
for session_id, session in self.session_manager._sessions.items()
}
return {}
def get_session(self, session_id: str) -> Optional[Any]:
"""
Get information about a specific session.
Args:
session_id: The ID of the session to retrieve
Returns:
Session information if found, None otherwise
"""
if hasattr(self.session_manager, '_sessions') and session_id in self.session_manager._sessions:
session = self.session_manager._sessions[session_id]
return {
"created_at": session.created_at,
"last_active": session.last_active,
"client_id": session.client_id if hasattr(session, 'client_id') else None,
}
return None
def create_starlette_app(mcp_server: Server, base_path: str = "/mcp") -> Starlette:
"""
Create a Starlette application with a mounted StreamableHTTPSessionManager.
Args:
mcp_server: The MCP Server instance
base_path: The base path to mount the session manager at
Returns:
A Starlette application
"""
session_manager = SessionAwareStreamableHTTPManager(
app=mcp_server,
stateless=False, # Use stateful sessions by default
)
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
# Log information about the incoming request
path = scope.get("path", "unknown path")
method = scope.get("method", "unknown method")
logger.info(f"Incoming request: {method} {path}")
# Process the request
await session_manager.handle_request(scope, receive, send)
@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[None]:
"""Context manager for session manager."""
async with session_manager.run():
logger.info(f"Application started with StreamableHTTP session manager at {base_path}")
try:
yield
finally:
logger.info("Application shutting down...")
app = Starlette(
debug=True,
routes=[
Mount(base_path, app=handle_streamable_http),
],
lifespan=lifespan,
)
# Create a middleware to set the FastMCP session ID header
class SessionHeaderMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# If we have a current session ID, add it to the request's headers
# This makes it available to FastMCP via the Header injection
if CURRENT_SESSION_ID:
# Save the session ID in an environment variable that FastMCP can access
os.environ["MCP_CURRENT_SESSION_ID"] = CURRENT_SESSION_ID
# Since we can't modify request.headers directly,
# we'll handle this in our SessionAwareStreamableHTTPManager
logger.debug(f"SessionHeaderMiddleware: Set environment session ID to {CURRENT_SESSION_ID}")
# Call the next middleware or endpoint
response = await call_next(request)
# Remove the environment variable after the request
if "MCP_CURRENT_SESSION_ID" in os.environ:
del os.environ["MCP_CURRENT_SESSION_ID"]
return response
# Add the middleware to the app
app.add_middleware(SessionHeaderMiddleware)
# Attach the session manager to the app for access elsewhere
app.state.session_manager = session_manager
return app, session_manager
# Function to get the current session ID (used by tools)
def get_current_session_id() -> Optional[str]:
"""
Get the session ID for the current request context.
Returns:
The session ID if available, None otherwise
"""
# First check the global variable (set during request handling)
if CURRENT_SESSION_ID:
return CURRENT_SESSION_ID
# Then check environment variable (set by middleware)
return os.environ.get("MCP_CURRENT_SESSION_ID")