fix cli mode
This commit is contained in:
@@ -14,6 +14,7 @@ Usage:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
@@ -24,6 +25,88 @@ from auth.oauth_config import set_transport_mode
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_fastapi_param_marker(default: Any) -> bool:
|
||||
"""
|
||||
Check if a default value is a FastAPI parameter marker (Body, Query, etc.).
|
||||
|
||||
These markers are metadata for HTTP request parsing and should not be passed
|
||||
directly to tool functions in CLI mode.
|
||||
"""
|
||||
default_type = type(default)
|
||||
return default_type.__module__ == "fastapi.params" and hasattr(
|
||||
default, "get_default"
|
||||
)
|
||||
|
||||
|
||||
def _is_required_marker_default(value: Any) -> bool:
|
||||
"""Check whether a FastAPI/Pydantic default represents a required field."""
|
||||
return value is Ellipsis or type(value).__name__ == "PydanticUndefinedType"
|
||||
|
||||
|
||||
def _extract_fastapi_default(default_marker: Any) -> tuple[bool, Any]:
|
||||
"""
|
||||
Resolve the runtime default from a FastAPI marker.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_required, resolved_default)
|
||||
"""
|
||||
try:
|
||||
resolved_default = default_marker.get_default(call_default_factory=True)
|
||||
except TypeError:
|
||||
# Compatibility path for implementations without call_default_factory kwarg
|
||||
resolved_default = default_marker.get_default()
|
||||
except Exception:
|
||||
resolved_default = getattr(default_marker, "default", inspect.Parameter.empty)
|
||||
|
||||
return _is_required_marker_default(resolved_default), resolved_default
|
||||
|
||||
|
||||
def _normalize_cli_args_for_tool(fn, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Fill omitted CLI args for FastAPI markers with their real defaults.
|
||||
|
||||
When tools are invoked via HTTP, FastAPI resolves Body/Query/... defaults.
|
||||
In CLI mode we invoke functions directly, so we need to do that resolution.
|
||||
"""
|
||||
normalized_args = dict(args)
|
||||
signature = inspect.signature(fn)
|
||||
missing_required = []
|
||||
|
||||
for param in signature.parameters.values():
|
||||
if param.kind in (
|
||||
inspect.Parameter.VAR_POSITIONAL,
|
||||
inspect.Parameter.VAR_KEYWORD,
|
||||
):
|
||||
continue
|
||||
|
||||
if param.name in normalized_args:
|
||||
continue
|
||||
|
||||
if param.default is inspect.Parameter.empty:
|
||||
continue
|
||||
|
||||
if not _is_fastapi_param_marker(param.default):
|
||||
continue
|
||||
|
||||
is_required, resolved_default = _extract_fastapi_default(param.default)
|
||||
if is_required:
|
||||
missing_required.append(param.name)
|
||||
else:
|
||||
normalized_args[param.name] = resolved_default
|
||||
|
||||
if missing_required:
|
||||
if len(missing_required) == 1:
|
||||
missing = missing_required[0]
|
||||
raise TypeError(f"{fn.__name__}() missing 1 required argument: '{missing}'")
|
||||
|
||||
missing = ", ".join(f"'{name}'" for name in missing_required)
|
||||
raise TypeError(
|
||||
f"{fn.__name__}() missing {len(missing_required)} required arguments: {missing}"
|
||||
)
|
||||
|
||||
return normalized_args
|
||||
|
||||
|
||||
def get_registered_tools(server) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all registered tools from the FastMCP server.
|
||||
@@ -233,14 +316,19 @@ async def run_tool(server, tool_name: str, args: Dict[str, Any]) -> str:
|
||||
if fn is None:
|
||||
raise ValueError(f"Tool '{tool_name}' has no callable function")
|
||||
|
||||
logger.debug(f"[CLI] Executing tool: {tool_name} with args: {list(args.keys())}")
|
||||
call_args = dict(args)
|
||||
|
||||
try:
|
||||
call_args = _normalize_cli_args_for_tool(fn, args)
|
||||
logger.debug(
|
||||
f"[CLI] Executing tool: {tool_name} with args: {list(call_args.keys())}"
|
||||
)
|
||||
|
||||
# Call the tool function
|
||||
if asyncio.iscoroutinefunction(fn):
|
||||
result = await fn(**args)
|
||||
result = await fn(**call_args)
|
||||
else:
|
||||
result = fn(**args)
|
||||
result = fn(**call_args)
|
||||
|
||||
# Convert result to string if needed
|
||||
if isinstance(result, str):
|
||||
@@ -257,7 +345,7 @@ async def run_tool(server, tool_name: str, args: Dict[str, Any]) -> str:
|
||||
return (
|
||||
f"Error calling {tool_name}: {error_msg}\n\n"
|
||||
f"Required parameters: {required}\n"
|
||||
f"Provided parameters: {list(args.keys())}"
|
||||
f"Provided parameters: {list(call_args.keys())}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[CLI] Error executing {tool_name}: {e}", exc_info=True)
|
||||
|
||||
Reference in New Issue
Block a user