v3 auth middleware fix

This commit is contained in:
Taylor Wilsdon
2026-02-13 10:20:39 -05:00
parent a3107e900b
commit 0075e8338f
7 changed files with 1086 additions and 1228 deletions

View File

@@ -131,9 +131,14 @@ def get_registered_tools(server) -> Dict[str, Any]:
"""
tools = {}
if hasattr(server, "_tool_manager") and hasattr(server._tool_manager, "_tools"):
tool_registry = server._tool_manager._tools
for name, tool in tool_registry.items():
# FastMCP v3: access tools via local_provider._components
lp = getattr(server, "local_provider", None)
if lp is not None:
components = getattr(lp, "_components", {})
for key, tool in components.items():
if not key.startswith("tool:"):
continue
name = key.split(":", 1)[1].rsplit("@", 1)[0]
tools[name] = {
"name": name,
"description": getattr(tool, "description", None)

View File

@@ -40,9 +40,9 @@ session_middleware = Middleware(MCPSessionMiddleware)
# Custom FastMCP that adds secure middleware stack for OAuth 2.1
class SecureFastMCP(FastMCP):
def streamable_http_app(self) -> "Starlette":
def http_app(self, **kwargs) -> "Starlette":
"""Override to add secure middleware stack for OAuth 2.1."""
app = super().streamable_http_app()
app = super().http_app(**kwargs)
# Add middleware in order (first added = outermost layer)
# Session Management - extracts session info for MCP context

View File

@@ -79,6 +79,22 @@ def wrap_server_tool_method(server):
server.tool = tracking_tool
def _get_tool_components(server) -> dict:
"""Get tool components dict from server's local_provider.
Returns a dict mapping tool_name -> tool_object for introspection.
"""
lp = server.local_provider
components = getattr(lp, "_components", {})
tools = {}
for key, component in components.items():
if key.startswith("tool:"):
# Keys are like "tool:name@version", extract the name
name = key.split(":", 1)[1].rsplit("@", 1)[0]
tools[name] = component
return tools
def filter_server_tools(server):
"""Remove disabled tools from the server after registration."""
enabled_tools = get_enabled_tools()
@@ -87,59 +103,49 @@ def filter_server_tools(server):
return
tools_removed = 0
lp = server.local_provider
tool_components = _get_tool_components(server)
# Access FastMCP's tool registry via _tool_manager._tools
if hasattr(server, "_tool_manager"):
tool_manager = server._tool_manager
if hasattr(tool_manager, "_tools"):
tool_registry = tool_manager._tools
read_only_mode = is_read_only_mode()
allowed_scopes = set(get_all_read_only_scopes()) if read_only_mode else None
read_only_mode = is_read_only_mode()
allowed_scopes = set(get_all_read_only_scopes()) if read_only_mode else None
tools_to_remove = set()
tools_to_remove = set()
# 1. Tier filtering
if enabled_tools is not None:
for tool_name in tool_components:
if not is_tool_enabled(tool_name):
tools_to_remove.add(tool_name)
# 1. Tier filtering
if enabled_tools is not None:
for tool_name in list(tool_registry.keys()):
if not is_tool_enabled(tool_name):
tools_to_remove.add(tool_name)
# 2. OAuth 2.1 filtering
if oauth21_enabled and "start_google_auth" in tool_components:
tools_to_remove.add("start_google_auth")
logger.info("OAuth 2.1 enabled: disabling start_google_auth tool")
# 2. OAuth 2.1 filtering
if oauth21_enabled and "start_google_auth" in tool_registry:
tools_to_remove.add("start_google_auth")
logger.info("OAuth 2.1 enabled: disabling start_google_auth tool")
# 3. Read-only mode filtering
if read_only_mode:
for tool_name, tool_obj in tool_components.items():
if tool_name in tools_to_remove:
continue
# 3. Read-only mode filtering
if read_only_mode:
for tool_name in list(tool_registry.keys()):
if tool_name in tools_to_remove:
continue
# Check if tool has required scopes attached (from @require_google_service)
func_to_check = tool_obj
if hasattr(tool_obj, "fn"):
func_to_check = tool_obj.fn
tool_func = tool_registry[tool_name]
# Check if tool has required scopes attached (from @require_google_service)
# Note: FastMCP wraps functions in Tool objects, so we need to check .fn if available
func_to_check = tool_func
if hasattr(tool_func, "fn"):
func_to_check = tool_func.fn
required_scopes = getattr(func_to_check, "_required_google_scopes", [])
required_scopes = getattr(
func_to_check, "_required_google_scopes", []
if required_scopes:
# If ANY required scope is not in the allowed read-only scopes, disable the tool
if not all(scope in allowed_scopes for scope in required_scopes):
logger.info(
f"Read-only mode: Disabling tool '{tool_name}' (requires write scopes: {required_scopes})"
)
tools_to_remove.add(tool_name)
if required_scopes:
# If ANY required scope is not in the allowed read-only scopes, disable the tool
if not all(
scope in allowed_scopes for scope in required_scopes
):
logger.info(
f"Read-only mode: Disabling tool '{tool_name}' (requires write scopes: {required_scopes})"
)
tools_to_remove.add(tool_name)
for tool_name in tools_to_remove:
del tool_registry[tool_name]
tools_removed += 1
for tool_name in tools_to_remove:
lp.remove_tool(tool_name)
tools_removed += 1
if tools_removed > 0:
enabled_count = len(enabled_tools) if enabled_tools is not None else "all"