refac fastmcp custom implementation to standard

This commit is contained in:
Taylor Wilsdon
2025-05-30 11:09:56 -04:00
parent b00fe41de8
commit 395d02494b
7 changed files with 261 additions and 730 deletions

View File

@@ -1,20 +1,14 @@
import logging
import os
import threading
from typing import Dict, Any, Optional
from fastapi import Request, Header
from fastapi import Header
from fastapi.responses import HTMLResponse
from mcp import types
from mcp.server.fastmcp import FastMCP
from fastmcp.server.dependencies import get_http_request
from starlette.requests import Request
from starlette.applications import Starlette
# Import our custom StreamableHTTP session manager
from core.streamable_http import SessionAwareStreamableHTTPManager, create_starlette_app
from google.auth.exceptions import RefreshError
from auth.google_auth import handle_auth_callback, start_auth_flow, CONFIG_CLIENT_SECRETS_PATH
@@ -50,86 +44,9 @@ WORKSPACE_MCP_BASE_URI = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
# Basic MCP server instance
server = FastMCP(
name="google_workspace",
server_url=f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/gworkspace", # Add absolute URL for Gemini native function calling
host="0.0.0.0", # Listen on all interfaces
port=WORKSPACE_MCP_PORT, # Default port for HTTP server
stateless_http=False # Enable stateful sessions (default)
server_url=f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/mcp", # Add absolute URL for Gemini native function calling
)
# Container for session manager
class SessionManagerContainer:
"""
Thread-safe container for the session manager instance.
This encapsulates the session manager to improve testability and thread safety
by avoiding direct global variable access.
"""
def __init__(self):
self._session_manager = None
self._lock = threading.Lock()
def set_session_manager(self, manager) -> None:
"""Set the session manager instance in a thread-safe manner."""
with self._lock:
self._session_manager = manager
def get_session_manager(self):
"""Get the current session manager instance in a thread-safe manner."""
with self._lock:
return self._session_manager
session_manager_container = SessionManagerContainer()
def get_session_manager():
"""
Get the current session manager instance.
Returns:
The session manager instance if initialized, None otherwise
"""
return session_manager_container.get_session_manager()
def log_all_active_sessions():
"""
Log information about all active sessions for debugging purposes.
"""
session_manager = session_manager_container.get_session_manager()
if session_manager is None:
logger.debug("Cannot log sessions: session_manager is not initialized")
return
active_sessions = session_manager.get_active_sessions()
session_count = len(active_sessions)
logger.debug(f"Active sessions: {session_count}")
for session_id, info in active_sessions.items():
logger.debug(f"Session ID: {session_id}, Created: {info.get('created_at')}, Last Active: {info.get('last_active')}")
def create_application(base_path="/gworkspace") -> Starlette:
"""
Create a Starlette application with the MCP server and session manager.
Args:
base_path: The base path to mount the MCP server at
Returns:
A Starlette application
"""
logger.info(f"Creating Starlette application with MCP server mounted at {base_path}")
app, manager = create_starlette_app(server._mcp_server, base_path)
session_manager_container.set_session_manager(manager)
# Add the OAuth callback route to the Starlette application
from starlette.routing import Route
# Add the OAuth callback route
app.routes.append(
Route("/oauth2callback", endpoint=oauth2_callback, methods=["GET"])
)
return app
# Configure OAuth redirect URI to use the MCP server's base uri and port
OAUTH_REDIRECT_URI = f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}/oauth2callback"
@@ -292,54 +209,6 @@ async def start_google_auth(
redirect_uri=OAUTH_REDIRECT_URI
)
@server.tool()
async def get_active_sessions() -> Dict[str, Any]:
"""
Retrieve information about all active MCP sessions.
LLM Guidance:
- Use this tool to get information about currently active sessions
- This is useful for debugging or when you need to understand the active user sessions
Returns:
A dictionary mapping session IDs to session information
"""
session_manager = session_manager_container.get_session_manager()
if session_manager is None:
logger.error("get_active_sessions called but session_manager is not initialized")
return {"error": "Session manager not initialized"}
active_sessions = session_manager.get_active_sessions()
return active_sessions
@server.tool()
async def get_session_info(session_id: str) -> Optional[Dict[str, Any]]:
"""
Retrieve information about a specific MCP session.
LLM Guidance:
- Use this tool when you need details about a specific session
- Provide the session_id parameter to identify which session to retrieve
Args:
session_id: The ID of the session to retrieve
Returns:
Session information if found, None otherwise
"""
session_manager = session_manager_container.get_session_manager()
if session_manager is None:
logger.error(f"get_session_info({session_id}) called but session_manager is not initialized")
return {"error": "Session manager not initialized"}
session_info = session_manager.get_session(session_id)
if session_info is None:
logger.debug(f"Session {session_id} not found")
return {"error": f"Session {session_id} not found"}
return session_info
@server.tool()
async def debug_current_session(
mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id")
@@ -349,46 +218,23 @@ async def debug_current_session(
LLM Guidance:
- Use this tool to verify that session tracking is working correctly
- This tool will return information about the current session based on the Mcp-Session-Id header
- This tool will return the current session ID from the Mcp-Session-Id header
Args:
mcp_session_id: The MCP session ID header (automatically injected)
Returns:
Information about the current session and all active sessions
Information about the current session
"""
session_manager = session_manager_container.get_session_manager()
# Get the HTTP request to access headers
req: Request = get_http_request()
headers = dict(req.headers)
# Log all active sessions for debugging
log_all_active_sessions()
result = {
"current_session": {
"session_id": mcp_session_id,
"headers": headers
},
"session_info": None,
"active_sessions_count": 0
"message": "Session tracking is handled natively by FastMCP"
}
}
# Get info for the current session if available
if session_manager is not None and mcp_session_id:
session_info = session_manager.get_session(mcp_session_id)
result["session_info"] = session_info
# Count active sessions
active_sessions = session_manager.get_active_sessions()
result["active_sessions_count"] = len(active_sessions)
result["active_session_ids"] = list(active_sessions.keys())
else:
result["error"] = "Unable to retrieve session information"
if session_manager is None:
result["error_details"] = "Session manager not initialized"
elif not mcp_session_id:
result["error_details"] = "No session ID provided in request"
if not mcp_session_id:
result["error"] = "No session ID provided in request"
result["error_details"] = "The Mcp-Session-Id header was not found"
return result