v3 auth middleware fix

This commit is contained in:
Taylor Wilsdon
2026-02-13 10:20:49 -05:00
4 changed files with 124 additions and 167 deletions

View File

@@ -14,110 +14,15 @@ Usage:
"""
import asyncio
import inspect
import json
import logging
import sys
from typing import Any, Dict, List, Optional
from auth.oauth_config import set_transport_mode
from pydantic_core import PydanticUndefined as _PYDANTIC_UNDEFINED
logger = logging.getLogger(__name__)
_PYDANTIC_UNDEFINED_TYPE = type(_PYDANTIC_UNDEFINED)
def _is_fastapi_param_marker(default: Any) -> bool:
"""
Check if a default value is a FastAPI parameter marker (Body, Query, etc.).
These markers are metadata for HTTP request parsing and should not be passed
directly to tool functions in CLI mode.
"""
default_type = type(default)
return default_type.__module__ == "fastapi.params" and hasattr(
default, "get_default"
)
def _is_required_marker_default(value: Any) -> bool:
"""Check whether a FastAPI/Pydantic default represents a required field."""
if value is Ellipsis or value is inspect.Parameter.empty:
return True
return value is _PYDANTIC_UNDEFINED or isinstance(value, _PYDANTIC_UNDEFINED_TYPE)
def _extract_fastapi_default(default_marker: Any) -> tuple[bool, Any]:
"""
Resolve the runtime default from a FastAPI marker.
Returns:
Tuple of (is_required, resolved_default)
"""
try:
resolved_default = default_marker.get_default(call_default_factory=True)
except TypeError:
# Compatibility path for implementations without call_default_factory kwarg
resolved_default = default_marker.get_default()
except Exception:
resolved_default = getattr(default_marker, "default", inspect.Parameter.empty)
return _is_required_marker_default(resolved_default), resolved_default
def _normalize_cli_args_for_tool(fn, args: Dict[str, Any]) -> Dict[str, Any]:
"""
Fill omitted CLI args for FastAPI markers with their real defaults.
When tools are invoked via HTTP, FastAPI resolves Body/Query/... defaults.
In CLI mode we invoke functions directly, so we need to do that resolution.
"""
normalized_args = dict(args)
signature = inspect.signature(fn)
missing_required = []
for param in signature.parameters.values():
if param.kind in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
continue
if param.name in normalized_args:
continue
if param.default is inspect.Parameter.empty:
continue
if not _is_fastapi_param_marker(param.default):
continue
is_required, resolved_default = _extract_fastapi_default(param.default)
if is_required or resolved_default is inspect.Parameter.empty:
missing_required.append(param.name)
else:
normalized_args[param.name] = resolved_default
if missing_required:
if len(missing_required) == 1:
missing = missing_required[0]
raise TypeError(
f"{fn.__name__}() missing 1 required positional argument: '{missing}'"
)
missing_names = [f"'{name}'" for name in missing_required]
if len(missing_names) == 2:
missing = " and ".join(missing_names)
else:
missing = ", ".join(missing_names[:-1]) + f" and {missing_names[-1]}"
raise TypeError(
f"{fn.__name__}() missing {len(missing_required)} required positional arguments: {missing}"
)
return normalized_args
def get_registered_tools(server) -> Dict[str, Any]:
"""
@@ -336,7 +241,6 @@ async def run_tool(server, tool_name: str, args: Dict[str, Any]) -> str:
call_args = dict(args)
try:
call_args = _normalize_cli_args_for_tool(fn, args)
logger.debug(
f"[CLI] Executing tool: {tool_name} with args: {list(call_args.keys())}"
)