feat: initial commit from workspace-mcp
Some checks failed
Check Maintainer Edits Enabled / check-maintainer-edits (pull_request) Has been cancelled
Check Maintainer Edits Enabled / check-maintainer-edits-internal (pull_request) Has been cancelled
Docker Build and Push to GHCR / build-and-push (pull_request) Has been cancelled
Ruff / ruff (pull_request) Has been cancelled
Some checks failed
Check Maintainer Edits Enabled / check-maintainer-edits (pull_request) Has been cancelled
Check Maintainer Edits Enabled / check-maintainer-edits-internal (pull_request) Has been cancelled
Docker Build and Push to GHCR / build-and-push (pull_request) Has been cancelled
Ruff / ruff (pull_request) Has been cancelled
This commit is contained in:
1
core/__init__.py
Normal file
1
core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Make the core directory a Python package
|
||||
108
core/api_enablement.py
Normal file
108
core/api_enablement.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import re
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
|
||||
# All enablement links share the same Cloud Console flow URL; only the API id
# in the query string differs, so generate the mapping from the id list
# instead of repeating the URL ten times.
_ENABLEMENT_URL_TEMPLATE = (
    "https://console.cloud.google.com/flows/enableapi?apiid={api_id}"
)

# Google API service ids -> direct "enable this API" Cloud Console links.
API_ENABLEMENT_LINKS: Dict[str, str] = {
    api_id: _ENABLEMENT_URL_TEMPLATE.format(api_id=api_id)
    for api_id in (
        "calendar-json.googleapis.com",
        "drive.googleapis.com",
        "gmail.googleapis.com",
        "docs.googleapis.com",
        "sheets.googleapis.com",
        "slides.googleapis.com",
        "forms.googleapis.com",
        "tasks.googleapis.com",
        "chat.googleapis.com",
        "customsearch.googleapis.com",
    )
}


# Human-readable product names -> API service ids (used to show a friendly
# display name in error messages).
SERVICE_NAME_TO_API: Dict[str, str] = {
    "Google Calendar": "calendar-json.googleapis.com",
    "Google Drive": "drive.googleapis.com",
    "Gmail": "gmail.googleapis.com",
    "Google Docs": "docs.googleapis.com",
    "Google Sheets": "sheets.googleapis.com",
    "Google Slides": "slides.googleapis.com",
    "Google Forms": "forms.googleapis.com",
    "Google Tasks": "tasks.googleapis.com",
    "Google Chat": "chat.googleapis.com",
    "Google Custom Search": "customsearch.googleapis.com",
}


# Internal (lowercase) service keys -> API service ids.
# Note "search" is an alias for "customsearch".
INTERNAL_SERVICE_TO_API: Dict[str, str] = {
    "calendar": "calendar-json.googleapis.com",
    "drive": "drive.googleapis.com",
    "gmail": "gmail.googleapis.com",
    "docs": "docs.googleapis.com",
    "sheets": "sheets.googleapis.com",
    "slides": "slides.googleapis.com",
    "forms": "forms.googleapis.com",
    "tasks": "tasks.googleapis.com",
    "chat": "chat.googleapis.com",
    "customsearch": "customsearch.googleapis.com",
    "search": "customsearch.googleapis.com",
}
|
||||
|
||||
|
||||
def extract_api_info_from_error(
    error_details: str,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract API service and project ID from error details.

    Returns:
        Tuple of (api_service, project_id) or (None, None) if not found
    """

    def _first_group(pattern: str) -> Optional[str]:
        # re.search returns None when the pattern is absent from the text.
        match = re.search(pattern, error_details)
        return match.group(1) if match else None

    return (
        # API id embedded in the console "overview" URL of the error payload.
        _first_group(
            r"https://console\.developers\.google\.com/apis/api/([^/]+)/overview"
        ),
        # Project id following "project=" or "project " in the message.
        _first_group(r"project[=\s]+([a-zA-Z0-9-]+)"),
    )
|
||||
|
||||
|
||||
def get_api_enablement_message(
    error_details: str, service_type: Optional[str] = None
) -> str:
    """
    Generate a helpful error message with direct API enablement link.

    Args:
        error_details: The error details string from the HttpError
        service_type: Optional service type (e.g., "calendar", "gmail", or "Google Calendar")

    Returns:
        Formatted error message with enablement link, or "" when the API
        cannot be identified from the error details or service_type.
    """
    api_service, project_id = extract_api_info_from_error(error_details)

    if not api_service and service_type:
        # Check internal service names first (e.g., "calendar", "gmail"),
        # then fall back to display names (e.g., "Google Calendar").
        api_service = INTERNAL_SERVICE_TO_API.get(service_type)
        if not api_service:
            api_service = SERVICE_NAME_TO_API.get(service_type)

    if api_service and api_service in API_ENABLEMENT_LINKS:
        enable_link = API_ENABLEMENT_LINKS[api_service]
        # Reverse lookup: find the friendly display name for the API id,
        # falling back to the raw id when no mapping exists.
        service_display_name = next(
            (name for name, api in SERVICE_NAME_TO_API.items() if api == api_service),
            api_service,
        )

        # BUGFIX: the original concatenated "...then try again." directly with
        # "IMPORTANT - LLM: ..." with no separator; add a paragraph break.
        message = (
            f"{service_display_name} API is not enabled for your project"
            f"{f' ({project_id})' if project_id else ''}.\n\n"
            f"Enable it here: {enable_link}\n\n"
            f"After enabling, wait 1-2 minutes for the change to propagate, then try again.\n\n"
            f"IMPORTANT - LLM: share the link provided as a clickable hyperlink and instruct the user to enable the required API."
        )

        return message

    return ""
|
||||
262
core/attachment_storage.py
Normal file
262
core/attachment_storage.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
Temporary attachment storage for Gmail attachments.
|
||||
|
||||
Stores attachments to local disk and returns file paths for direct access.
|
||||
Files are automatically cleaned up after expiration (default 1 hour).
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Optional, Dict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default expiration: 1 hour
DEFAULT_EXPIRATION_SECONDS = 3600

# Storage directory - configurable via WORKSPACE_ATTACHMENT_DIR env var
# Uses absolute path to avoid creating tmp/ in arbitrary working directories (see #327)
_default_dir = str(Path.home() / ".workspace-mcp" / "attachments")
# expanduser() honors a leading "~" in the env var value; resolve() makes the
# path absolute so later file operations are independent of the CWD.
STORAGE_DIR = (
    Path(os.getenv("WORKSPACE_ATTACHMENT_DIR", _default_dir)).expanduser().resolve()
)
|
||||
|
||||
|
||||
def _ensure_storage_dir() -> None:
    """Create the storage directory on first use, not at import time."""
    # mode=0o700: owner-only access, since attachments may hold sensitive
    # email/drive content. exist_ok=True makes repeated calls idempotent.
    STORAGE_DIR.mkdir(parents=True, exist_ok=True, mode=0o700)
|
||||
|
||||
|
||||
class SavedAttachment(NamedTuple):
    """Result of saving an attachment: provides both the UUID and the absolute file path."""

    # Unique identifier (uuid4 string) used as the metadata key in AttachmentStorage.
    file_id: str
    # Absolute path on local disk where the attachment bytes were written.
    path: str
|
||||
|
||||
|
||||
class AttachmentStorage:
    """Manages temporary storage of email attachments.

    Files are written under STORAGE_DIR with owner-only permissions and
    tracked in an in-memory metadata map keyed by a uuid4 file id. Entries
    expire after ``expiration_seconds``; expired files are deleted lazily on
    access or in bulk via cleanup_expired().
    """

    def __init__(self, expiration_seconds: int = DEFAULT_EXPIRATION_SECONDS):
        # How long a saved attachment remains retrievable.
        self.expiration_seconds = expiration_seconds
        # file_id -> {file_path, filename, mime_type, size, created_at, expires_at}
        # NOTE: in-memory only — metadata is lost on restart even though the
        # files themselves persist on disk.
        self._metadata: Dict[str, Dict] = {}

    def save_attachment(
        self,
        base64_data: str,
        filename: Optional[str] = None,
        mime_type: Optional[str] = None,
    ) -> SavedAttachment:
        """
        Save an attachment to local disk.

        Args:
            base64_data: Base64-encoded attachment data
            filename: Original filename (optional)
            mime_type: MIME type (optional)

        Returns:
            SavedAttachment with file_id (UUID) and path (absolute file path)

        Raises:
            ValueError: If base64_data cannot be decoded.
            OSError: If the file cannot be written to disk.
        """
        _ensure_storage_dir()

        # Generate unique file ID for metadata tracking
        file_id = str(uuid.uuid4())

        # Decode base64 data
        # urlsafe variant: accepts "-"/"_" alphabet — presumably matching
        # Gmail's attachment encoding (TODO confirm against caller).
        try:
            file_bytes = base64.urlsafe_b64decode(base64_data)
        except Exception as e:
            logger.error(f"Failed to decode base64 attachment data: {e}")
            raise ValueError(f"Invalid base64 data: {e}")

        # Determine file extension from filename or mime type
        extension = ""
        if filename:
            extension = Path(filename).suffix
        elif mime_type:
            # Basic mime type to extension mapping
            mime_to_ext = {
                "image/jpeg": ".jpg",
                "image/png": ".png",
                "image/gif": ".gif",
                "application/pdf": ".pdf",
                "application/zip": ".zip",
                "text/plain": ".txt",
                "text/html": ".html",
            }
            extension = mime_to_ext.get(mime_type, "")

        # Use original filename if available, with UUID suffix for uniqueness
        if filename:
            stem = Path(filename).stem
            ext = Path(filename).suffix
            save_name = f"{stem}_{file_id[:8]}{ext}"
        else:
            save_name = f"{file_id}{extension}"

        # Save file with restrictive permissions (sensitive email/drive content)
        file_path = STORAGE_DIR / save_name
        try:
            # os.open (rather than open()) so the 0o600 mode is applied at
            # creation time; O_BINARY is Windows-only, getattr makes it a
            # harmless 0 elsewhere.
            fd = os.open(
                file_path,
                os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0),
                0o600,
            )
            try:
                # os.write may perform a partial write; loop until all bytes
                # are flushed or no progress is made.
                total_written = 0
                data_len = len(file_bytes)
                while total_written < data_len:
                    written = os.write(fd, file_bytes[total_written:])
                    if written == 0:
                        raise OSError(
                            "os.write returned 0 bytes; could not write attachment data"
                        )
                    total_written += written
            finally:
                os.close(fd)
            logger.info(
                f"Saved attachment file_id={file_id} filename={filename or save_name} "
                f"({len(file_bytes)} bytes) to {file_path}"
            )
        except Exception as e:
            logger.error(
                f"Failed to save attachment file_id={file_id} "
                f"filename={filename or save_name} to {file_path}: {e}"
            )
            raise

        # Store metadata
        expires_at = datetime.now() + timedelta(seconds=self.expiration_seconds)
        self._metadata[file_id] = {
            "file_path": str(file_path),
            "filename": filename or f"attachment{extension}",
            "mime_type": mime_type or "application/octet-stream",
            "size": len(file_bytes),
            "created_at": datetime.now(),
            "expires_at": expires_at,
        }

        return SavedAttachment(file_id=file_id, path=str(file_path))

    def get_attachment_path(self, file_id: str) -> Optional[Path]:
        """
        Get the file path for an attachment ID.

        Args:
            file_id: Unique file ID

        Returns:
            Path object if file exists and not expired, None otherwise
        """
        if file_id not in self._metadata:
            logger.warning(f"Attachment {file_id} not found in metadata")
            return None

        metadata = self._metadata[file_id]
        file_path = Path(metadata["file_path"])

        # Check if expired
        if datetime.now() > metadata["expires_at"]:
            logger.info(f"Attachment {file_id} has expired, cleaning up")
            self._cleanup_file(file_id)
            return None

        # Check if file exists
        # (may have been deleted externally); drop the stale metadata entry.
        if not file_path.exists():
            logger.warning(f"Attachment file {file_path} does not exist")
            del self._metadata[file_id]
            return None

        return file_path

    def get_attachment_metadata(self, file_id: str) -> Optional[Dict]:
        """
        Get metadata for an attachment.

        Args:
            file_id: Unique file ID

        Returns:
            Metadata dict (a copy) if exists and not expired, None otherwise
        """
        if file_id not in self._metadata:
            return None

        # Copy so callers cannot mutate the internal record.
        metadata = self._metadata[file_id].copy()

        # Check if expired
        if datetime.now() > metadata["expires_at"]:
            self._cleanup_file(file_id)
            return None

        return metadata

    def _cleanup_file(self, file_id: str) -> None:
        """Remove file and metadata."""
        if file_id in self._metadata:
            file_path = Path(self._metadata[file_id]["file_path"])
            try:
                if file_path.exists():
                    file_path.unlink()
                    logger.debug(f"Deleted expired attachment file: {file_path}")
            except Exception as e:
                # Best-effort: the metadata entry is removed below even if the
                # unlink fails, so the id is never handed out again.
                logger.warning(f"Failed to delete attachment file {file_path}: {e}")
            del self._metadata[file_id]

    def cleanup_expired(self) -> int:
        """
        Clean up expired attachments.

        Returns:
            Number of files cleaned up
        """
        now = datetime.now()
        # Collect ids first, then delete: avoids mutating the dict while
        # iterating over it.
        expired_ids = [
            file_id
            for file_id, metadata in self._metadata.items()
            if now > metadata["expires_at"]
        ]

        for file_id in expired_ids:
            self._cleanup_file(file_id)

        return len(expired_ids)
|
||||
|
||||
|
||||
# Global instance
# Lazily-created process-wide singleton; see get_attachment_storage().
_attachment_storage: Optional[AttachmentStorage] = None


def get_attachment_storage() -> AttachmentStorage:
    """Get the global attachment storage instance."""
    global _attachment_storage
    # Lazy init: the first caller constructs the storage with the default
    # expiration; subsequent callers reuse it.
    if _attachment_storage is None:
        _attachment_storage = AttachmentStorage()
    return _attachment_storage
|
||||
|
||||
|
||||
def get_attachment_url(file_id: str) -> str:
    """
    Generate a URL for accessing an attachment.

    Args:
        file_id: Unique file ID

    Returns:
        Full URL to access the attachment
    """
    # Function-scope import — presumably to avoid a circular import at module
    # load time (core.config may import this module); confirm before hoisting.
    from core.config import WORKSPACE_MCP_PORT, WORKSPACE_MCP_BASE_URI

    # An explicitly-configured external URL (reverse-proxy deployments) wins
    # over the locally-derived base URI + port.
    external_url = os.getenv("WORKSPACE_EXTERNAL_URL")
    base_url = (
        external_url.rstrip("/")
        if external_url
        else f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}"
    )

    return f"{base_url}/attachments/{file_id}"
|
||||
410
core/cli_handler.py
Normal file
410
core/cli_handler.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
CLI Handler for Google Workspace MCP
|
||||
|
||||
This module provides a command-line interface mode for directly invoking
|
||||
MCP tools without running the full server. Designed for use by coding agents
|
||||
(Codex, Claude Code) and command-line users.
|
||||
|
||||
Usage:
|
||||
workspace-mcp --cli # List available tools
|
||||
workspace-mcp --cli list # List available tools
|
||||
workspace-mcp --cli <tool_name> # Run tool (reads JSON args from stdin)
|
||||
workspace-mcp --cli <tool_name> --args '{"key": "value"}' # Run with inline args
|
||||
workspace-mcp --cli <tool_name> --help # Show tool details
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from auth.oauth_config import set_transport_mode
|
||||
from core.tool_registry import get_tool_components
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_registered_tools(server) -> Dict[str, Any]:
    """
    Get all registered tools from the FastMCP server.

    Args:
        server: The FastMCP server instance

    Returns:
        Dictionary mapping tool names to their metadata
    """
    return {
        name: {
            "name": name,
            # Prefer the tool's declared description; fall back to the first
            # meaningful docstring line.
            "description": getattr(tool, "description", None)
            or _extract_docstring(tool),
            "parameters": _extract_parameters(tool),
            # Keep the live tool object so callers can invoke it.
            "tool_obj": tool,
        }
        for name, tool in get_tool_components(server).items()
    }
|
||||
|
||||
|
||||
def _extract_docstring(tool) -> Optional[str]:
|
||||
"""Extract the first meaningful line of a tool's docstring as its description."""
|
||||
fn = getattr(tool, "fn", None) or tool
|
||||
if fn and fn.__doc__:
|
||||
# Get first non-empty line that's not just "Args:" etc.
|
||||
for line in fn.__doc__.strip().split("\n"):
|
||||
line = line.strip()
|
||||
# Skip empty lines and common section headers
|
||||
if line and not line.startswith(
|
||||
("Args:", "Returns:", "Raises:", "Example", "Note:")
|
||||
):
|
||||
return line
|
||||
return None
|
||||
|
||||
|
||||
def _extract_parameters(tool) -> Dict[str, Any]:
|
||||
"""Extract parameter information from a tool."""
|
||||
params = {}
|
||||
|
||||
# Try to get parameters from the tool's schema
|
||||
if hasattr(tool, "parameters"):
|
||||
schema = tool.parameters
|
||||
if isinstance(schema, dict):
|
||||
props = schema.get("properties", {})
|
||||
required = set(schema.get("required", []))
|
||||
for name, prop in props.items():
|
||||
params[name] = {
|
||||
"type": prop.get("type", "any"),
|
||||
"description": prop.get("description", ""),
|
||||
"required": name in required,
|
||||
"default": prop.get("default"),
|
||||
}
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def list_tools(server, output_format: str = "text") -> str:
    """
    List all available tools.

    Args:
        server: The FastMCP server instance
        output_format: Output format ("text" or "json")

    Returns:
        Formatted string listing all tools
    """
    tools = get_registered_tools(server)

    if output_format == "json":
        # Machine-readable listing: drop the live tool object, keep metadata.
        payload = {
            "tools": [
                {
                    "name": name,
                    "description": info["description"],
                    "parameters": info["parameters"],
                }
                for name, info in sorted(tools.items())
            ]
        }
        return json.dumps(payload, indent=2)

    # Human-readable listing, grouped by the service prefix of the tool name
    # (e.g. "gmail_send" groups under GMAIL).
    lines = [
        f"Available tools ({len(tools)}):",
        "",
    ]

    services: Dict[str, list] = {}
    for name, info in tools.items():
        prefix = name.split("_")[0] if "_" in name else "other"
        services.setdefault(prefix, []).append((name, info))

    for service in sorted(services):
        lines.append(f"  {service.upper()}:")
        for name, info in sorted(services[service]):
            desc = info["description"] or "(no description)"
            # First line only, truncated so the listing stays compact.
            summary = desc.split("\n")[0].strip()
            if len(summary) > 70:
                summary = summary[:67] + "..."
            lines.append(f"    {name}")
            lines.append(f"      {summary}")
        lines.append("")

    lines.append("Use --cli <tool_name> --help for detailed tool information")
    lines.append("Use --cli <tool_name> --args '{...}' to run a tool")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def show_tool_help(server, tool_name: str) -> str:
    """
    Show detailed help for a specific tool.

    Args:
        server: The FastMCP server instance
        tool_name: Name of the tool

    Returns:
        Formatted help string for the tool
    """
    tools = get_registered_tools(server)

    if tool_name not in tools:
        # Unknown tool: include a small sample of valid names to guide the user.
        available = ", ".join(sorted(tools.keys())[:10])
        return f"Error: Tool '{tool_name}' not found.\n\nAvailable tools include: {available}..."

    tool_info = tools[tool_name]
    tool_obj = tool_info["tool_obj"]

    # Get full docstring
    # FastMCP tools wrap the underlying callable in .fn; fall back to the
    # tool object itself.
    fn = getattr(tool_obj, "fn", None) or tool_obj
    docstring = fn.__doc__ if fn and fn.__doc__ else "(no documentation)"

    lines = [
        f"Tool: {tool_name}",
        # Underline sized to match the "Tool: <name>" header width.
        "=" * (len(tool_name) + 6),
        "",
        docstring,
        "",
        "Parameters:",
    ]

    params = tool_info["parameters"]
    if params:
        for name, param_info in params.items():
            req = "(required)" if param_info.get("required") else "(optional)"
            param_type = param_info.get("type", "any")
            desc = param_info.get("description", "")
            default = param_info.get("default")

            lines.append(f"  {name}: {param_type} {req}")
            if desc:
                lines.append(f"    {desc}")
            if default is not None:
                lines.append(f"    Default: {default}")
    else:
        lines.append("  (no parameters)")

    lines.extend(
        [
            "",
            "Example usage:",
            f'  workspace-mcp --cli {tool_name} --args \'{{"param": "value"}}\'',
            "",
            "Or pipe JSON from stdin:",
            f'  echo \'{{"param": "value"}}\' | workspace-mcp --cli {tool_name}',
        ]
    )

    return "\n".join(lines)
|
||||
|
||||
|
||||
async def run_tool(server, tool_name: str, args: Dict[str, Any]) -> str:
    """
    Execute a tool with the provided arguments.

    Args:
        server: The FastMCP server instance
        tool_name: Name of the tool to execute
        args: Dictionary of arguments to pass to the tool

    Returns:
        Tool result as a string

    Raises:
        ValueError: If the tool does not exist or has no callable function.
    """
    tools = get_registered_tools(server)

    if tool_name not in tools:
        raise ValueError(f"Tool '{tool_name}' not found")

    tool_info = tools[tool_name]
    tool_obj = tool_info["tool_obj"]

    # Get the actual function to call
    fn = getattr(tool_obj, "fn", None)
    if fn is None:
        raise ValueError(f"Tool '{tool_name}' has no callable function")

    # Copy so the caller's dict is never mutated.
    call_args = dict(args)

    try:
        logger.debug(
            f"[CLI] Executing tool: {tool_name} with args: {list(call_args.keys())}"
        )

        # Call the tool function
        # Async tools are awaited on the current loop; sync tools run inline.
        if asyncio.iscoroutinefunction(fn):
            result = await fn(**call_args)
        else:
            result = fn(**call_args)

        # Convert result to string if needed
        if isinstance(result, str):
            return result
        else:
            return json.dumps(result, indent=2, default=str)

    except TypeError as e:
        # Provide helpful error for missing/invalid arguments
        # NOTE(review): this also catches TypeErrors raised *inside* the tool
        # body, which would then be misreported as argument errors — confirm
        # whether that trade-off is acceptable.
        error_msg = str(e)
        params = tool_info["parameters"]
        required = [n for n, p in params.items() if p.get("required")]

        return (
            f"Error calling {tool_name}: {error_msg}\n\n"
            f"Required parameters: {required}\n"
            f"Provided parameters: {list(call_args.keys())}"
        )
    except Exception as e:
        # Any other failure is logged with a traceback and reported briefly.
        logger.error(f"[CLI] Error executing {tool_name}: {e}", exc_info=True)
        return f"Error: {type(e).__name__}: {e}"
|
||||
|
||||
|
||||
def parse_cli_args(args: List[str]) -> Dict[str, Any]:
    """
    Parse CLI arguments for tool execution.

    Args:
        args: List of arguments after --cli

    Returns:
        Dictionary with parsed values:
        - command: "list", "help", or "run"
        - tool_name: Name of tool (if applicable)
        - tool_args: Arguments for the tool (if applicable)
        - output_format: "text" or "json"

    Raises:
        ValueError: If the --args payload is not valid JSON.
    """
    parsed: Dict[str, Any] = {
        "command": "list",
        "tool_name": None,
        "tool_args": {},
        "output_format": "text",
    }

    index = 0
    total = len(args)
    while index < total:
        token = args[index]

        if token in ("list", "-l", "--list"):
            parsed["command"] = "list"
        elif token in ("--json", "-j"):
            parsed["output_format"] = "json"
        elif token in ("help", "--help", "-h"):
            if parsed["tool_name"]:
                # "<tool> --help": show help for the already-named tool.
                parsed["command"] = "help"
            elif index + 1 < total and not args[index + 1].startswith("-"):
                # "--help <tool>": consume the following token as the tool name.
                parsed["tool_name"] = args[index + 1]
                parsed["command"] = "help"
                index += 1
            else:
                # Bare help with no tool: fall back to the general listing.
                parsed["command"] = "list"
        elif token in ("--args", "-a") and index + 1 < total:
            # Inline JSON arguments follow the flag.
            raw_json = args[index + 1]
            try:
                parsed["tool_args"] = json.loads(raw_json)
            except json.JSONDecodeError as exc:
                raise ValueError(
                    f"Invalid JSON in --args: {exc}\n"
                    f"Received: {repr(raw_json)}\n"
                    f"Tip: Try using stdin instead: echo '<json>' | workspace-mcp --cli <tool>"
                )
            index += 1
        elif not token.startswith("-") and not parsed["tool_name"]:
            # First bare token names the tool to run.
            parsed["tool_name"] = token
            parsed["command"] = "run"
        # Unrecognized flags are silently skipped.
        index += 1

    return parsed
|
||||
|
||||
|
||||
def read_stdin_args() -> Dict[str, Any]:
    """
    Read JSON arguments from stdin if available.

    Returns:
        Dictionary of arguments, or an empty dict if stdin is a TTY or no
        data is provided.

    Raises:
        ValueError: If stdin contains data that is not valid JSON, or is
            valid JSON but not a JSON object (callers merge the result into
            a kwargs dict, so only an object makes sense).
    """
    if sys.stdin.isatty():
        logger.debug("[CLI] stdin is a TTY; no JSON args will be read from stdin")
        return {}

    try:
        stdin_data = sys.stdin.read().strip()
        if stdin_data:
            parsed = json.loads(stdin_data)
            # ROBUSTNESS: json.loads can return a list/str/number; callers do
            # args.update(...) and **kwargs, which require a mapping — fail
            # with a clear message instead of an obscure TypeError later.
            if not isinstance(parsed, dict):
                raise ValueError(
                    f"JSON from stdin must be an object, got {type(parsed).__name__}"
                )
            return parsed
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON from stdin: {e}")

    return {}
|
||||
|
||||
|
||||
async def handle_cli_mode(server, cli_args: List[str]) -> int:
    """
    Main entry point for CLI mode.

    Args:
        server: The FastMCP server instance
        cli_args: Arguments passed after --cli

    Returns:
        Exit code (0 for success, 1 for error)
    """
    # Set transport mode to "stdio" so OAuth callback server starts when needed
    # This is required for authentication flow when no cached credentials exist
    set_transport_mode("stdio")

    try:
        parsed = parse_cli_args(cli_args)

        if parsed["command"] == "list":
            output = list_tools(server, parsed["output_format"])
            print(output)
            return 0

        if parsed["command"] == "help":
            output = show_tool_help(server, parsed["tool_name"])
            print(output)
            return 0

        if parsed["command"] == "run":
            # Merge stdin args with inline args (inline takes precedence)
            args = read_stdin_args()
            args.update(parsed["tool_args"])

            result = await run_tool(server, parsed["tool_name"], args)
            print(result)
            return 0

        # Unknown command
        # Defensive: parse_cli_args only emits "list"/"help"/"run" today.
        print(f"Unknown command: {parsed['command']}")
        return 1

    except ValueError as e:
        # User error (bad JSON, unknown tool): report without a traceback.
        print(f"Error: {e}", file=sys.stderr)
        return 1
    except Exception as e:
        # Anything else is unexpected: log with traceback, report briefly.
        logger.error(f"[CLI] Unexpected error: {e}", exc_info=True)
        print(f"Error: {e}", file=sys.stderr)
        return 1
|
||||
305
core/comments.py
Normal file
305
core/comments.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Core Comments Module
|
||||
|
||||
This module provides reusable comment management functions for Google Workspace applications.
|
||||
All Google Workspace apps (Docs, Sheets, Slides) use the Drive API for comment operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from auth.service_decorator import require_google_service
|
||||
from core.server import server
|
||||
from core.utils import handle_http_errors
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _manage_comment_dispatch(
|
||||
service,
|
||||
app_name: str,
|
||||
file_id: str,
|
||||
action: str,
|
||||
comment_content: Optional[str] = None,
|
||||
comment_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Route comment management actions to the appropriate implementation."""
|
||||
action_lower = action.lower().strip()
|
||||
if action_lower == "create":
|
||||
if not comment_content:
|
||||
raise ValueError("comment_content is required for create action")
|
||||
return await _create_comment_impl(service, app_name, file_id, comment_content)
|
||||
elif action_lower == "reply":
|
||||
if not comment_id or not comment_content:
|
||||
raise ValueError(
|
||||
"comment_id and comment_content are required for reply action"
|
||||
)
|
||||
return await _reply_to_comment_impl(
|
||||
service, app_name, file_id, comment_id, comment_content
|
||||
)
|
||||
elif action_lower == "resolve":
|
||||
if not comment_id:
|
||||
raise ValueError("comment_id is required for resolve action")
|
||||
return await _resolve_comment_impl(service, app_name, file_id, comment_id)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid action '{action_lower}'. Must be 'create', 'reply', or 'resolve'."
|
||||
)
|
||||
|
||||
|
||||
def create_comment_tools(app_name: str, file_id_param: str):
    """
    Factory function to create comment management tools for a specific Google Workspace app.

    Args:
        app_name: Name of the app (e.g., "document", "spreadsheet", "presentation")
        file_id_param: Parameter name for the file ID (e.g., "document_id", "spreadsheet_id", "presentation_id")

    Returns:
        Dict containing the comment management functions with unique names
    """

    # --- Consolidated tools ---
    list_func_name = f"list_{app_name}_comments"
    manage_func_name = f"manage_{app_name}_comment"

    # The three branches below are deliberately near-identical: each defines
    # the tool functions with the file-id parameter *literally named*
    # document_id / spreadsheet_id / presentation_id — presumably because the
    # tool's exposed schema is derived from the signature's parameter names
    # (TODO confirm against the decorators' implementation).
    if file_id_param == "document_id":

        @require_google_service("drive", "drive_read")
        @handle_http_errors(list_func_name, service_type="drive")
        async def list_comments(
            service, user_google_email: str, document_id: str
        ) -> str:
            """List all comments from a Google Document."""
            return await _read_comments_impl(service, app_name, document_id)

        @require_google_service("drive", "drive_file")
        @handle_http_errors(manage_func_name, service_type="drive")
        async def manage_comment(
            service,
            user_google_email: str,
            document_id: str,
            action: str,
            comment_content: Optional[str] = None,
            comment_id: Optional[str] = None,
        ) -> str:
            """Manage comments on a Google Document.

            Actions:
            - create: Create a new comment. Requires comment_content.
            - reply: Reply to a comment. Requires comment_id and comment_content.
            - resolve: Resolve a comment. Requires comment_id.
            """
            return await _manage_comment_dispatch(
                service, app_name, document_id, action, comment_content, comment_id
            )

    elif file_id_param == "spreadsheet_id":

        @require_google_service("drive", "drive_read")
        @handle_http_errors(list_func_name, service_type="drive")
        async def list_comments(
            service, user_google_email: str, spreadsheet_id: str
        ) -> str:
            """List all comments from a Google Spreadsheet."""
            return await _read_comments_impl(service, app_name, spreadsheet_id)

        @require_google_service("drive", "drive_file")
        @handle_http_errors(manage_func_name, service_type="drive")
        async def manage_comment(
            service,
            user_google_email: str,
            spreadsheet_id: str,
            action: str,
            comment_content: Optional[str] = None,
            comment_id: Optional[str] = None,
        ) -> str:
            """Manage comments on a Google Spreadsheet.

            Actions:
            - create: Create a new comment. Requires comment_content.
            - reply: Reply to a comment. Requires comment_id and comment_content.
            - resolve: Resolve a comment. Requires comment_id.
            """
            return await _manage_comment_dispatch(
                service, app_name, spreadsheet_id, action, comment_content, comment_id
            )

    elif file_id_param == "presentation_id":

        @require_google_service("drive", "drive_read")
        @handle_http_errors(list_func_name, service_type="drive")
        async def list_comments(
            service, user_google_email: str, presentation_id: str
        ) -> str:
            """List all comments from a Google Presentation."""
            return await _read_comments_impl(service, app_name, presentation_id)

        @require_google_service("drive", "drive_file")
        @handle_http_errors(manage_func_name, service_type="drive")
        async def manage_comment(
            service,
            user_google_email: str,
            presentation_id: str,
            action: str,
            comment_content: Optional[str] = None,
            comment_id: Optional[str] = None,
        ) -> str:
            """Manage comments on a Google Presentation.

            Actions:
            - create: Create a new comment. Requires comment_content.
            - reply: Reply to a comment. Requires comment_id and comment_content.
            - resolve: Resolve a comment. Requires comment_id.
            """
            return await _manage_comment_dispatch(
                service, app_name, presentation_id, action, comment_content, comment_id
            )

    # NOTE(review): if file_id_param is none of the three expected values,
    # list_comments/manage_comment are never bound and the lines below raise
    # NameError — confirm whether an explicit ValueError would be preferable.
    # Rename so each app registers distinctly-named tools with the server.
    list_comments.__name__ = list_func_name
    manage_comment.__name__ = manage_func_name
    server.tool()(list_comments)
    server.tool()(manage_comment)

    return {
        "list_comments": list_comments,
        "manage_comment": manage_comment,
    }
|
||||
|
||||
|
||||
async def _read_comments_impl(service, app_name: str, file_id: str) -> str:
|
||||
"""Implementation for reading comments from any Google Workspace file."""
|
||||
logger.info(f"[read_{app_name}_comments] Reading comments for {app_name} {file_id}")
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
service.comments()
|
||||
.list(
|
||||
fileId=file_id,
|
||||
fields="comments(id,content,author,createdTime,modifiedTime,resolved,quotedFileContent,replies(content,author,id,createdTime,modifiedTime))",
|
||||
)
|
||||
.execute
|
||||
)
|
||||
|
||||
comments = response.get("comments", [])
|
||||
|
||||
if not comments:
|
||||
return f"No comments found in {app_name} {file_id}"
|
||||
|
||||
output = [f"Found {len(comments)} comments in {app_name} {file_id}:\\n"]
|
||||
|
||||
for comment in comments:
|
||||
author = comment.get("author", {}).get("displayName", "Unknown")
|
||||
content = comment.get("content", "")
|
||||
created = comment.get("createdTime", "")
|
||||
resolved = comment.get("resolved", False)
|
||||
comment_id = comment.get("id", "")
|
||||
status = " [RESOLVED]" if resolved else ""
|
||||
|
||||
quoted_text = comment.get("quotedFileContent", {}).get("value", "")
|
||||
|
||||
output.append(f"Comment ID: {comment_id}")
|
||||
output.append(f"Author: {author}")
|
||||
output.append(f"Created: {created}{status}")
|
||||
if quoted_text:
|
||||
output.append(f"Quoted text: {quoted_text}")
|
||||
output.append(f"Content: {content}")
|
||||
|
||||
# Add replies if any
|
||||
replies = comment.get("replies", [])
|
||||
if replies:
|
||||
output.append(f" Replies ({len(replies)}):")
|
||||
for reply in replies:
|
||||
reply_author = reply.get("author", {}).get("displayName", "Unknown")
|
||||
reply_content = reply.get("content", "")
|
||||
reply_created = reply.get("createdTime", "")
|
||||
reply_id = reply.get("id", "")
|
||||
output.append(f" Reply ID: {reply_id}")
|
||||
output.append(f" Author: {reply_author}")
|
||||
output.append(f" Created: {reply_created}")
|
||||
output.append(f" Content: {reply_content}")
|
||||
|
||||
output.append("") # Empty line between comments
|
||||
|
||||
return "\\n".join(output)
|
||||
|
||||
|
||||
async def _create_comment_impl(
|
||||
service, app_name: str, file_id: str, comment_content: str
|
||||
) -> str:
|
||||
"""Implementation for creating a comment on any Google Workspace file."""
|
||||
logger.info(f"[create_{app_name}_comment] Creating comment in {app_name} {file_id}")
|
||||
|
||||
body = {"content": comment_content}
|
||||
|
||||
comment = await asyncio.to_thread(
|
||||
service.comments()
|
||||
.create(
|
||||
fileId=file_id,
|
||||
body=body,
|
||||
fields="id,content,author,createdTime,modifiedTime",
|
||||
)
|
||||
.execute
|
||||
)
|
||||
|
||||
comment_id = comment.get("id", "")
|
||||
author = comment.get("author", {}).get("displayName", "Unknown")
|
||||
created = comment.get("createdTime", "")
|
||||
|
||||
return f"Comment created successfully!\\nComment ID: {comment_id}\\nAuthor: {author}\\nCreated: {created}\\nContent: {comment_content}"
|
||||
|
||||
|
||||
async def _reply_to_comment_impl(
|
||||
service, app_name: str, file_id: str, comment_id: str, reply_content: str
|
||||
) -> str:
|
||||
"""Implementation for replying to a comment on any Google Workspace file."""
|
||||
logger.info(
|
||||
f"[reply_to_{app_name}_comment] Replying to comment {comment_id} in {app_name} {file_id}"
|
||||
)
|
||||
|
||||
body = {"content": reply_content}
|
||||
|
||||
reply = await asyncio.to_thread(
|
||||
service.replies()
|
||||
.create(
|
||||
fileId=file_id,
|
||||
commentId=comment_id,
|
||||
body=body,
|
||||
fields="id,content,author,createdTime,modifiedTime",
|
||||
)
|
||||
.execute
|
||||
)
|
||||
|
||||
reply_id = reply.get("id", "")
|
||||
author = reply.get("author", {}).get("displayName", "Unknown")
|
||||
created = reply.get("createdTime", "")
|
||||
|
||||
return f"Reply posted successfully!\\nReply ID: {reply_id}\\nAuthor: {author}\\nCreated: {created}\\nContent: {reply_content}"
|
||||
|
||||
|
||||
async def _resolve_comment_impl(
|
||||
service, app_name: str, file_id: str, comment_id: str
|
||||
) -> str:
|
||||
"""Implementation for resolving a comment on any Google Workspace file."""
|
||||
logger.info(
|
||||
f"[resolve_{app_name}_comment] Resolving comment {comment_id} in {app_name} {file_id}"
|
||||
)
|
||||
|
||||
body = {"content": "This comment has been resolved.", "action": "resolve"}
|
||||
|
||||
reply = await asyncio.to_thread(
|
||||
service.replies()
|
||||
.create(
|
||||
fileId=file_id,
|
||||
commentId=comment_id,
|
||||
body=body,
|
||||
fields="id,content,author,createdTime,modifiedTime",
|
||||
)
|
||||
.execute
|
||||
)
|
||||
|
||||
reply_id = reply.get("id", "")
|
||||
author = reply.get("author", {}).get("displayName", "Unknown")
|
||||
created = reply.get("createdTime", "")
|
||||
|
||||
return f"Comment {comment_id} has been resolved successfully.\\nResolve reply ID: {reply_id}\\nAuthor: {author}\\nCreated: {created}"
|
||||
37
core/config.py
Normal file
37
core/config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Shared configuration for Google Workspace MCP server.
|
||||
This module holds configuration values that need to be shared across modules
|
||||
to avoid circular imports.
|
||||
|
||||
NOTE: OAuth configuration has been moved to auth.oauth_config for centralization.
|
||||
This module now imports from there for backward compatibility.
|
||||
"""
|
||||
|
||||
import os
|
||||
from auth.oauth_config import (
|
||||
get_oauth_base_url,
|
||||
get_oauth_redirect_uri,
|
||||
set_transport_mode,
|
||||
get_transport_mode,
|
||||
is_oauth21_enabled,
|
||||
)
|
||||
|
||||
# Server configuration
|
||||
WORKSPACE_MCP_PORT = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)))
|
||||
WORKSPACE_MCP_BASE_URI = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
|
||||
|
||||
# Disable USER_GOOGLE_EMAIL in OAuth 2.1 multi-user mode
|
||||
USER_GOOGLE_EMAIL = (
|
||||
None if is_oauth21_enabled() else os.getenv("USER_GOOGLE_EMAIL", None)
|
||||
)
|
||||
|
||||
# Re-export OAuth functions for backward compatibility
|
||||
__all__ = [
|
||||
"WORKSPACE_MCP_PORT",
|
||||
"WORKSPACE_MCP_BASE_URI",
|
||||
"USER_GOOGLE_EMAIL",
|
||||
"get_oauth_base_url",
|
||||
"get_oauth_redirect_uri",
|
||||
"set_transport_mode",
|
||||
"get_transport_mode",
|
||||
]
|
||||
43
core/context.py
Normal file
43
core/context.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# core/context.py
|
||||
import contextvars
|
||||
from typing import Optional
|
||||
|
||||
# Context variable to hold injected credentials for the life of a single request.
|
||||
_injected_oauth_credentials = contextvars.ContextVar(
|
||||
"injected_oauth_credentials", default=None
|
||||
)
|
||||
|
||||
# Context variable to hold FastMCP session ID for the life of a single request.
|
||||
_fastmcp_session_id = contextvars.ContextVar("fastmcp_session_id", default=None)
|
||||
|
||||
|
||||
def get_injected_oauth_credentials():
|
||||
"""
|
||||
Retrieve injected OAuth credentials for the current request context.
|
||||
This is called by the authentication layer to check for request-scoped credentials.
|
||||
"""
|
||||
return _injected_oauth_credentials.get()
|
||||
|
||||
|
||||
def set_injected_oauth_credentials(credentials: Optional[dict]):
|
||||
"""
|
||||
Set or clear the injected OAuth credentials for the current request context.
|
||||
This is called by the service decorator.
|
||||
"""
|
||||
_injected_oauth_credentials.set(credentials)
|
||||
|
||||
|
||||
def get_fastmcp_session_id() -> Optional[str]:
|
||||
"""
|
||||
Retrieve the FastMCP session ID for the current request context.
|
||||
This is called by authentication layer to get the current session.
|
||||
"""
|
||||
return _fastmcp_session_id.get()
|
||||
|
||||
|
||||
def set_fastmcp_session_id(session_id: Optional[str]):
|
||||
"""
|
||||
Set or clear the FastMCP session ID for the current request context.
|
||||
This is called when a FastMCP request starts.
|
||||
"""
|
||||
_fastmcp_session_id.set(session_id)
|
||||
207
core/log_formatter.py
Normal file
207
core/log_formatter.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Enhanced Log Formatter for Google Workspace MCP
|
||||
|
||||
Provides visually appealing log formatting with emojis and consistent styling
|
||||
to match the safe_print output format.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
||||
class EnhancedLogFormatter(logging.Formatter):
|
||||
"""Custom log formatter that adds ASCII prefixes and visual enhancements to log messages."""
|
||||
|
||||
# Color codes for terminals that support ANSI colors
|
||||
COLORS = {
|
||||
"DEBUG": "\033[36m", # Cyan
|
||||
"INFO": "\033[32m", # Green
|
||||
"WARNING": "\033[33m", # Yellow
|
||||
"ERROR": "\033[31m", # Red
|
||||
"CRITICAL": "\033[35m", # Magenta
|
||||
"RESET": "\033[0m", # Reset
|
||||
}
|
||||
|
||||
def __init__(self, use_colors: bool = True, *args, **kwargs):
|
||||
"""
|
||||
Initialize the emoji log formatter.
|
||||
|
||||
Args:
|
||||
use_colors: Whether to use ANSI color codes (default: True)
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.use_colors = use_colors
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Format the log record with ASCII prefixes and enhanced styling."""
|
||||
# Get the appropriate ASCII prefix for the service
|
||||
service_prefix = self._get_ascii_prefix(record.name, record.levelname)
|
||||
|
||||
# Format the message with enhanced styling
|
||||
formatted_msg = self._enhance_message(record.getMessage())
|
||||
|
||||
# Build the formatted log entry
|
||||
if self.use_colors:
|
||||
color = self.COLORS.get(record.levelname, "")
|
||||
reset = self.COLORS["RESET"]
|
||||
return f"{service_prefix} {color}{formatted_msg}{reset}"
|
||||
else:
|
||||
return f"{service_prefix} {formatted_msg}"
|
||||
|
||||
def _get_ascii_prefix(self, logger_name: str, level_name: str) -> str:
|
||||
"""Get ASCII-safe prefix for Windows compatibility."""
|
||||
# ASCII-safe prefixes for different services
|
||||
ascii_prefixes = {
|
||||
"core.tool_tier_loader": "[TOOLS]",
|
||||
"core.tool_registry": "[REGISTRY]",
|
||||
"auth.scopes": "[AUTH]",
|
||||
"core.utils": "[UTILS]",
|
||||
"auth.google_auth": "[OAUTH]",
|
||||
"auth.credential_store": "[CREDS]",
|
||||
"gcalendar.calendar_tools": "[CALENDAR]",
|
||||
"gdrive.drive_tools": "[DRIVE]",
|
||||
"gmail.gmail_tools": "[GMAIL]",
|
||||
"gdocs.docs_tools": "[DOCS]",
|
||||
"gsheets.sheets_tools": "[SHEETS]",
|
||||
"gchat.chat_tools": "[CHAT]",
|
||||
"gforms.forms_tools": "[FORMS]",
|
||||
"gslides.slides_tools": "[SLIDES]",
|
||||
"gtasks.tasks_tools": "[TASKS]",
|
||||
"gsearch.search_tools": "[SEARCH]",
|
||||
}
|
||||
|
||||
return ascii_prefixes.get(logger_name, f"[{level_name}]")
|
||||
|
||||
def _enhance_message(self, message: str) -> str:
|
||||
"""Enhance the log message with better formatting."""
|
||||
# Handle common patterns for better visual appeal
|
||||
|
||||
# Tool tier loading messages
|
||||
if "resolved to" in message and "tools across" in message:
|
||||
# Extract numbers and service names for better formatting
|
||||
pattern = (
|
||||
r"Tier '(\w+)' resolved to (\d+) tools across (\d+) services: (.+)"
|
||||
)
|
||||
match = re.search(pattern, message)
|
||||
if match:
|
||||
tier, tool_count, service_count, services = match.groups()
|
||||
return f"Tool tier '{tier}' loaded: {tool_count} tools across {service_count} services [{services}]"
|
||||
|
||||
# Configuration loading messages
|
||||
if "Loaded tool tiers configuration from" in message:
|
||||
path = message.split("from ")[-1]
|
||||
return f"Configuration loaded from {path}"
|
||||
|
||||
# Tool filtering messages
|
||||
if "Tool tier filtering" in message:
|
||||
pattern = r"removed (\d+) tools, (\d+) enabled"
|
||||
match = re.search(pattern, message)
|
||||
if match:
|
||||
removed, enabled = match.groups()
|
||||
return f"Tool filtering complete: {enabled} tools enabled ({removed} filtered out)"
|
||||
|
||||
# Enabled tools messages
|
||||
if "Enabled tools set for scope management" in message:
|
||||
tools = message.split(": ")[-1]
|
||||
return f"Scope management configured for tools: {tools}"
|
||||
|
||||
# Credentials directory messages
|
||||
if "Credentials directory permissions check passed" in message:
|
||||
path = message.split(": ")[-1]
|
||||
return f"Credentials directory verified: {path}"
|
||||
|
||||
# If no specific pattern matches, return the original message
|
||||
return message
|
||||
|
||||
|
||||
def setup_enhanced_logging(
|
||||
log_level: int = logging.INFO, use_colors: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Set up enhanced logging with ASCII prefix formatter for the entire application.
|
||||
|
||||
Args:
|
||||
log_level: The logging level to use (default: INFO)
|
||||
use_colors: Whether to use ANSI colors (default: True)
|
||||
"""
|
||||
# Create the enhanced formatter
|
||||
formatter = EnhancedLogFormatter(use_colors=use_colors)
|
||||
|
||||
# Get the root logger
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
# Update existing console handlers
|
||||
for handler in root_logger.handlers:
|
||||
if isinstance(handler, logging.StreamHandler) and handler.stream.name in [
|
||||
"<stderr>",
|
||||
"<stdout>",
|
||||
]:
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
# If no console handler exists, create one
|
||||
console_handlers = [
|
||||
h
|
||||
for h in root_logger.handlers
|
||||
if isinstance(h, logging.StreamHandler)
|
||||
and h.stream.name in ["<stderr>", "<stdout>"]
|
||||
]
|
||||
|
||||
if not console_handlers:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(log_level)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
|
||||
def configure_file_logging(logger_name: str = None) -> bool:
|
||||
"""
|
||||
Configure file logging based on stateless mode setting.
|
||||
|
||||
In stateless mode, file logging is completely disabled to avoid filesystem writes.
|
||||
In normal mode, sets up detailed file logging to 'mcp_server_debug.log'.
|
||||
|
||||
Args:
|
||||
logger_name: Optional name for the logger (defaults to root logger)
|
||||
|
||||
Returns:
|
||||
bool: True if file logging was configured, False if skipped (stateless mode)
|
||||
"""
|
||||
# Check if stateless mode is enabled
|
||||
stateless_mode = (
|
||||
os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
|
||||
)
|
||||
|
||||
if stateless_mode:
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.debug("File logging disabled in stateless mode")
|
||||
return False
|
||||
|
||||
# Configure file logging for normal mode
|
||||
try:
|
||||
target_logger = logging.getLogger(logger_name)
|
||||
log_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# Go up one level since we're in core/ subdirectory
|
||||
log_file_dir = os.path.dirname(log_file_dir)
|
||||
log_file_path = os.path.join(log_file_dir, "mcp_server_debug.log")
|
||||
|
||||
file_handler = logging.FileHandler(log_file_path, mode="a")
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
|
||||
file_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(process)d - %(threadName)s "
|
||||
"[%(module)s.%(funcName)s:%(lineno)d] - %(message)s"
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
target_logger.addHandler(file_handler)
|
||||
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.debug(f"Detailed file logging configured to: {log_file_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
sys.stderr.write(
|
||||
f"CRITICAL: Failed to set up file logging to '{log_file_path}': {e}\n"
|
||||
)
|
||||
return False
|
||||
620
core/server.py
Normal file
620
core/server.py
Normal file
@@ -0,0 +1,620 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from importlib import metadata
|
||||
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
|
||||
from starlette.applications import Starlette
|
||||
from starlette.datastructures import MutableHeaders
|
||||
from starlette.types import Scope, Receive, Send
|
||||
from starlette.requests import Request
|
||||
from starlette.middleware import Middleware
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.auth.providers.google import GoogleProvider
|
||||
|
||||
from auth.oauth21_session_store import get_oauth21_session_store, set_auth_provider
|
||||
from auth.google_auth import handle_auth_callback, start_auth_flow, check_client_secrets
|
||||
from auth.oauth_config import is_oauth21_enabled, is_external_oauth21_provider
|
||||
from auth.mcp_session_middleware import MCPSessionMiddleware
|
||||
from auth.oauth_responses import (
|
||||
create_error_response,
|
||||
create_success_response,
|
||||
create_server_error_response,
|
||||
)
|
||||
from auth.auth_info_middleware import AuthInfoMiddleware
|
||||
from auth.scopes import SCOPES, get_current_scopes # noqa
|
||||
from core.config import (
|
||||
USER_GOOGLE_EMAIL,
|
||||
get_transport_mode,
|
||||
set_transport_mode as _set_transport_mode,
|
||||
get_oauth_redirect_uri as get_oauth_redirect_uri_for_current_mode,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_auth_provider: Optional[GoogleProvider] = None
|
||||
_legacy_callback_registered = False
|
||||
|
||||
session_middleware = Middleware(MCPSessionMiddleware)
|
||||
|
||||
|
||||
class WellKnownCacheControlMiddleware:
|
||||
"""Force no-cache headers for OAuth well-known discovery endpoints."""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
path = scope.get("path", "")
|
||||
is_oauth_well_known = (
|
||||
path == "/.well-known/oauth-authorization-server"
|
||||
or path.startswith("/.well-known/oauth-authorization-server/")
|
||||
or path == "/.well-known/oauth-protected-resource"
|
||||
or path.startswith("/.well-known/oauth-protected-resource/")
|
||||
)
|
||||
if not is_oauth_well_known:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
async def send_with_no_cache_headers(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = MutableHeaders(raw=message.setdefault("headers", []))
|
||||
headers["Cache-Control"] = "no-store, must-revalidate"
|
||||
headers["ETag"] = f'"{_compute_scope_fingerprint()}"'
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_with_no_cache_headers)
|
||||
|
||||
|
||||
well_known_cache_control_middleware = Middleware(WellKnownCacheControlMiddleware)
|
||||
|
||||
|
||||
def _compute_scope_fingerprint() -> str:
|
||||
"""Compute a short hash of the current scope configuration for cache-busting."""
|
||||
scopes_str = ",".join(sorted(get_current_scopes()))
|
||||
return hashlib.sha256(scopes_str.encode()).hexdigest()[:12]
|
||||
|
||||
|
||||
# Custom FastMCP that adds secure middleware stack for OAuth 2.1
|
||||
class SecureFastMCP(FastMCP):
|
||||
def http_app(self, **kwargs) -> "Starlette":
|
||||
"""Override to add secure middleware stack for OAuth 2.1."""
|
||||
app = super().http_app(**kwargs)
|
||||
|
||||
# Add middleware in order (first added = outermost layer)
|
||||
app.user_middleware.insert(0, well_known_cache_control_middleware)
|
||||
|
||||
# Session Management - extracts session info for MCP context
|
||||
app.user_middleware.insert(1, session_middleware)
|
||||
|
||||
# Rebuild middleware stack
|
||||
app.middleware_stack = app.build_middleware_stack()
|
||||
logger.info("Added middleware stack: WellKnownCacheControl, Session Management")
|
||||
return app
|
||||
|
||||
|
||||
# Build server instructions with user email context for single-user mode
|
||||
_server_instructions = None
|
||||
if USER_GOOGLE_EMAIL:
|
||||
_server_instructions = f"""Connected Google account: {USER_GOOGLE_EMAIL}
|
||||
|
||||
When using Google Workspace tools, always use `{USER_GOOGLE_EMAIL}` as the `user_google_email` parameter. Do not ask the user for their email address."""
|
||||
logger.info(f"Server instructions configured for user: {USER_GOOGLE_EMAIL}")
|
||||
|
||||
server = SecureFastMCP(
|
||||
name="google_workspace",
|
||||
auth=None,
|
||||
instructions=_server_instructions,
|
||||
)
|
||||
|
||||
# Add the AuthInfo middleware to inject authentication into FastMCP context
|
||||
auth_info_middleware = AuthInfoMiddleware()
|
||||
server.add_middleware(auth_info_middleware)
|
||||
|
||||
|
||||
def _parse_bool_env(value: str) -> bool:
|
||||
"""Parse environment variable string to boolean."""
|
||||
return value.lower() in ("1", "true", "yes", "on")
|
||||
|
||||
|
||||
def set_transport_mode(mode: str):
|
||||
"""Sets the transport mode for the server."""
|
||||
_set_transport_mode(mode)
|
||||
logger.info(f"Transport: {mode}")
|
||||
|
||||
|
||||
def _ensure_legacy_callback_route() -> None:
|
||||
global _legacy_callback_registered
|
||||
if _legacy_callback_registered:
|
||||
return
|
||||
server.custom_route("/oauth2callback", methods=["GET"])(legacy_oauth2_callback)
|
||||
_legacy_callback_registered = True
|
||||
|
||||
|
||||
def configure_server_for_http():
|
||||
"""
|
||||
Configures the authentication provider for HTTP transport.
|
||||
This must be called BEFORE server.run().
|
||||
"""
|
||||
global _auth_provider
|
||||
|
||||
transport_mode = get_transport_mode()
|
||||
|
||||
if transport_mode != "streamable-http":
|
||||
return
|
||||
|
||||
# Use centralized OAuth configuration
|
||||
from auth.oauth_config import get_oauth_config
|
||||
|
||||
config = get_oauth_config()
|
||||
|
||||
# Check if OAuth 2.1 is enabled via centralized config
|
||||
oauth21_enabled = config.is_oauth21_enabled()
|
||||
|
||||
if oauth21_enabled:
|
||||
if not config.is_configured():
|
||||
logger.warning("OAuth 2.1 enabled but OAuth credentials not configured")
|
||||
return
|
||||
|
||||
def validate_and_derive_jwt_key(
|
||||
jwt_signing_key_override: str | None, client_secret: str
|
||||
) -> bytes:
|
||||
"""Validate JWT signing key override and derive the final JWT key."""
|
||||
if jwt_signing_key_override:
|
||||
if len(jwt_signing_key_override) < 12:
|
||||
logger.warning(
|
||||
"OAuth 2.1: FASTMCP_SERVER_AUTH_GOOGLE_JWT_SIGNING_KEY is less than 12 characters; "
|
||||
"use a longer secret to improve key derivation strength."
|
||||
)
|
||||
return derive_jwt_key(
|
||||
low_entropy_material=jwt_signing_key_override,
|
||||
salt="fastmcp-jwt-signing-key",
|
||||
)
|
||||
else:
|
||||
return derive_jwt_key(
|
||||
high_entropy_material=client_secret,
|
||||
salt="fastmcp-jwt-signing-key",
|
||||
)
|
||||
|
||||
try:
|
||||
# Import common dependencies for storage backends
|
||||
from key_value.aio.wrappers.encryption import FernetEncryptionWrapper
|
||||
from cryptography.fernet import Fernet
|
||||
from fastmcp.server.auth.jwt_issuer import derive_jwt_key
|
||||
|
||||
required_scopes: List[str] = sorted(get_current_scopes())
|
||||
|
||||
client_storage = None
|
||||
jwt_signing_key_override = (
|
||||
os.getenv("FASTMCP_SERVER_AUTH_GOOGLE_JWT_SIGNING_KEY", "").strip()
|
||||
or None
|
||||
)
|
||||
storage_backend = (
|
||||
os.getenv("WORKSPACE_MCP_OAUTH_PROXY_STORAGE_BACKEND", "")
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
valkey_host = os.getenv("WORKSPACE_MCP_OAUTH_PROXY_VALKEY_HOST", "").strip()
|
||||
|
||||
# Determine storage backend: valkey, disk, memory (default)
|
||||
use_valkey = storage_backend == "valkey" or bool(valkey_host)
|
||||
use_disk = storage_backend == "disk"
|
||||
|
||||
if use_valkey:
|
||||
try:
|
||||
from key_value.aio.stores.valkey import ValkeyStore
|
||||
|
||||
valkey_port_raw = os.getenv(
|
||||
"WORKSPACE_MCP_OAUTH_PROXY_VALKEY_PORT", "6379"
|
||||
).strip()
|
||||
valkey_db_raw = os.getenv(
|
||||
"WORKSPACE_MCP_OAUTH_PROXY_VALKEY_DB", "0"
|
||||
).strip()
|
||||
|
||||
valkey_port = int(valkey_port_raw)
|
||||
valkey_db = int(valkey_db_raw)
|
||||
valkey_use_tls_raw = os.getenv(
|
||||
"WORKSPACE_MCP_OAUTH_PROXY_VALKEY_USE_TLS", ""
|
||||
).strip()
|
||||
valkey_use_tls = (
|
||||
_parse_bool_env(valkey_use_tls_raw)
|
||||
if valkey_use_tls_raw
|
||||
else valkey_port == 6380
|
||||
)
|
||||
|
||||
valkey_request_timeout_ms_raw = os.getenv(
|
||||
"WORKSPACE_MCP_OAUTH_PROXY_VALKEY_REQUEST_TIMEOUT_MS", ""
|
||||
).strip()
|
||||
valkey_connection_timeout_ms_raw = os.getenv(
|
||||
"WORKSPACE_MCP_OAUTH_PROXY_VALKEY_CONNECTION_TIMEOUT_MS", ""
|
||||
).strip()
|
||||
|
||||
valkey_request_timeout_ms = (
|
||||
int(valkey_request_timeout_ms_raw)
|
||||
if valkey_request_timeout_ms_raw
|
||||
else None
|
||||
)
|
||||
valkey_connection_timeout_ms = (
|
||||
int(valkey_connection_timeout_ms_raw)
|
||||
if valkey_connection_timeout_ms_raw
|
||||
else None
|
||||
)
|
||||
|
||||
valkey_username = (
|
||||
os.getenv(
|
||||
"WORKSPACE_MCP_OAUTH_PROXY_VALKEY_USERNAME", ""
|
||||
).strip()
|
||||
or None
|
||||
)
|
||||
valkey_password = (
|
||||
os.getenv(
|
||||
"WORKSPACE_MCP_OAUTH_PROXY_VALKEY_PASSWORD", ""
|
||||
).strip()
|
||||
or None
|
||||
)
|
||||
|
||||
if not valkey_host:
|
||||
valkey_host = "localhost"
|
||||
|
||||
client_storage = ValkeyStore(
|
||||
host=valkey_host,
|
||||
port=valkey_port,
|
||||
db=valkey_db,
|
||||
username=valkey_username,
|
||||
password=valkey_password,
|
||||
)
|
||||
|
||||
# Configure TLS and timeouts on the underlying Glide client config.
|
||||
# ValkeyStore currently doesn't expose these settings directly.
|
||||
glide_config = getattr(client_storage, "_client_config", None)
|
||||
if glide_config is not None:
|
||||
glide_config.use_tls = valkey_use_tls
|
||||
|
||||
is_remote_host = valkey_host not in {"localhost", "127.0.0.1"}
|
||||
if valkey_request_timeout_ms is None and (
|
||||
valkey_use_tls or is_remote_host
|
||||
):
|
||||
# Glide defaults to 250ms if unset; increase for remote/TLS endpoints.
|
||||
valkey_request_timeout_ms = 5000
|
||||
if valkey_request_timeout_ms is not None:
|
||||
glide_config.request_timeout = valkey_request_timeout_ms
|
||||
|
||||
if valkey_connection_timeout_ms is None and (
|
||||
valkey_use_tls or is_remote_host
|
||||
):
|
||||
valkey_connection_timeout_ms = 10000
|
||||
if valkey_connection_timeout_ms is not None:
|
||||
from glide_shared.config import (
|
||||
AdvancedGlideClientConfiguration,
|
||||
)
|
||||
|
||||
glide_config.advanced_config = (
|
||||
AdvancedGlideClientConfiguration(
|
||||
connection_timeout=valkey_connection_timeout_ms
|
||||
)
|
||||
)
|
||||
|
||||
jwt_signing_key = validate_and_derive_jwt_key(
|
||||
jwt_signing_key_override, config.client_secret
|
||||
)
|
||||
|
||||
storage_encryption_key = derive_jwt_key(
|
||||
high_entropy_material=jwt_signing_key.decode(),
|
||||
salt="fastmcp-storage-encryption-key",
|
||||
)
|
||||
|
||||
client_storage = FernetEncryptionWrapper(
|
||||
key_value=client_storage,
|
||||
fernet=Fernet(key=storage_encryption_key),
|
||||
)
|
||||
logger.info(
|
||||
"OAuth 2.1: Using ValkeyStore for FastMCP OAuth proxy client_storage (host=%s, port=%s, db=%s, tls=%s)",
|
||||
valkey_host,
|
||||
valkey_port,
|
||||
valkey_db,
|
||||
valkey_use_tls,
|
||||
)
|
||||
if valkey_request_timeout_ms is not None:
|
||||
logger.info(
|
||||
"OAuth 2.1: Valkey request timeout set to %sms",
|
||||
valkey_request_timeout_ms,
|
||||
)
|
||||
if valkey_connection_timeout_ms is not None:
|
||||
logger.info(
|
||||
"OAuth 2.1: Valkey connection timeout set to %sms",
|
||||
valkey_connection_timeout_ms,
|
||||
)
|
||||
logger.info(
|
||||
"OAuth 2.1: Applied Fernet encryption wrapper to Valkey client_storage (key derived from FASTMCP_SERVER_AUTH_GOOGLE_JWT_SIGNING_KEY or GOOGLE_OAUTH_CLIENT_SECRET)."
|
||||
)
|
||||
except ImportError as exc:
|
||||
logger.warning(
|
||||
"OAuth 2.1: Valkey client_storage requested but Valkey dependencies are not installed (%s). "
|
||||
"Install 'workspace-mcp[valkey]' (or 'py-key-value-aio[valkey]', which includes 'valkey-glide') "
|
||||
"or unset WORKSPACE_MCP_OAUTH_PROXY_STORAGE_BACKEND/WORKSPACE_MCP_OAUTH_PROXY_VALKEY_HOST.",
|
||||
exc,
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"OAuth 2.1: Invalid Valkey configuration; falling back to default storage (%s).",
|
||||
exc,
|
||||
)
|
||||
elif use_disk:
|
||||
try:
|
||||
from key_value.aio.stores.filetree import FileTreeStore
|
||||
|
||||
disk_directory = os.getenv(
|
||||
"WORKSPACE_MCP_OAUTH_PROXY_DISK_DIRECTORY", ""
|
||||
).strip()
|
||||
if not disk_directory:
|
||||
# Default to FASTMCP_HOME/oauth-proxy or ~/.fastmcp/oauth-proxy
|
||||
fastmcp_home = os.getenv("FASTMCP_HOME", "").strip()
|
||||
if fastmcp_home:
|
||||
disk_directory = os.path.join(fastmcp_home, "oauth-proxy")
|
||||
else:
|
||||
disk_directory = os.path.expanduser(
|
||||
"~/.fastmcp/oauth-proxy"
|
||||
)
|
||||
|
||||
client_storage = FileTreeStore(data_directory=disk_directory)
|
||||
|
||||
jwt_signing_key = validate_and_derive_jwt_key(
|
||||
jwt_signing_key_override, config.client_secret
|
||||
)
|
||||
|
||||
storage_encryption_key = derive_jwt_key(
|
||||
high_entropy_material=jwt_signing_key.decode(),
|
||||
salt="fastmcp-storage-encryption-key",
|
||||
)
|
||||
|
||||
client_storage = FernetEncryptionWrapper(
|
||||
key_value=client_storage,
|
||||
fernet=Fernet(key=storage_encryption_key),
|
||||
)
|
||||
logger.info(
|
||||
"OAuth 2.1: Using FileTreeStore for FastMCP OAuth proxy client_storage (directory=%s)",
|
||||
disk_directory,
|
||||
)
|
||||
except ImportError as exc:
|
||||
logger.warning(
|
||||
"OAuth 2.1: Disk storage requested but dependencies not available (%s). "
|
||||
"Falling back to default storage.",
|
||||
exc,
|
||||
)
|
||||
elif storage_backend == "memory":
|
||||
from key_value.aio.stores.memory import MemoryStore
|
||||
|
||||
client_storage = MemoryStore()
|
||||
logger.info(
|
||||
"OAuth 2.1: Using MemoryStore for FastMCP OAuth proxy client_storage"
|
||||
)
|
||||
# else: client_storage remains None, FastMCP uses its default
|
||||
|
||||
# Ensure JWT signing key is always derived for all storage backends
|
||||
if "jwt_signing_key" not in locals():
|
||||
jwt_signing_key = validate_and_derive_jwt_key(
|
||||
jwt_signing_key_override, config.client_secret
|
||||
)
|
||||
|
||||
# Check if external OAuth provider is configured
|
||||
if config.is_external_oauth21_provider():
|
||||
# External OAuth mode: use custom provider that handles ya29.* access tokens
|
||||
from auth.external_oauth_provider import ExternalOAuthProvider
|
||||
|
||||
provider = ExternalOAuthProvider(
|
||||
client_id=config.client_id,
|
||||
client_secret=config.client_secret,
|
||||
base_url=config.get_oauth_base_url(),
|
||||
redirect_path=config.redirect_path,
|
||||
required_scopes=required_scopes,
|
||||
resource_server_url=config.get_oauth_base_url(),
|
||||
)
|
||||
server.auth = provider
|
||||
|
||||
logger.info("OAuth 2.1 enabled with EXTERNAL provider mode")
|
||||
logger.info(
|
||||
"Expecting Authorization bearer tokens in tool call headers"
|
||||
)
|
||||
logger.info(
|
||||
"Protected resource metadata points to Google's authorization server"
|
||||
)
|
||||
else:
|
||||
# Standard OAuth 2.1 mode: use FastMCP's GoogleProvider
|
||||
provider = GoogleProvider(
|
||||
client_id=config.client_id,
|
||||
client_secret=config.client_secret,
|
||||
base_url=config.get_oauth_base_url(),
|
||||
redirect_path=config.redirect_path,
|
||||
required_scopes=required_scopes,
|
||||
client_storage=client_storage,
|
||||
jwt_signing_key=jwt_signing_key,
|
||||
)
|
||||
# Enable protocol-level auth
|
||||
server.auth = provider
|
||||
logger.info(
|
||||
"OAuth 2.1 enabled using FastMCP GoogleProvider with protocol-level auth"
|
||||
)
|
||||
|
||||
# Always set auth provider for token validation in middleware
|
||||
set_auth_provider(provider)
|
||||
_auth_provider = provider
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to initialize FastMCP GoogleProvider: %s", exc, exc_info=True
|
||||
)
|
||||
raise
|
||||
else:
|
||||
logger.info("OAuth 2.0 mode - Server will use legacy authentication.")
|
||||
server.auth = None
|
||||
_auth_provider = None
|
||||
set_auth_provider(None)
|
||||
_ensure_legacy_callback_route()
|
||||
|
||||
|
||||
def get_auth_provider() -> Optional[GoogleProvider]:
    """Return the module-level authentication provider, or None if unset."""
    return _auth_provider
|
||||
|
||||
|
||||
@server.custom_route("/", methods=["GET"])
|
||||
@server.custom_route("/health", methods=["GET"])
|
||||
async def health_check(request: Request):
|
||||
try:
|
||||
version = metadata.version("workspace-mcp")
|
||||
except metadata.PackageNotFoundError:
|
||||
version = "dev"
|
||||
return JSONResponse(
|
||||
{
|
||||
"status": "healthy",
|
||||
"service": "workspace-mcp",
|
||||
"version": version,
|
||||
"transport": get_transport_mode(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@server.custom_route("/attachments/{file_id}", methods=["GET"])
|
||||
async def serve_attachment(request: Request):
|
||||
"""Serve a stored attachment file."""
|
||||
from core.attachment_storage import get_attachment_storage
|
||||
|
||||
file_id = request.path_params["file_id"]
|
||||
storage = get_attachment_storage()
|
||||
metadata = storage.get_attachment_metadata(file_id)
|
||||
|
||||
if not metadata:
|
||||
return JSONResponse(
|
||||
{"error": "Attachment not found or expired"}, status_code=404
|
||||
)
|
||||
|
||||
file_path = storage.get_attachment_path(file_id)
|
||||
if not file_path:
|
||||
return JSONResponse({"error": "Attachment file not found"}, status_code=404)
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=metadata["filename"],
|
||||
media_type=metadata["mime_type"],
|
||||
)
|
||||
|
||||
|
||||
async def legacy_oauth2_callback(request: Request) -> HTMLResponse:
    """Handle Google's redirect in the legacy OAuth 2.0 flow.

    Reads ``state``/``code``/``error`` from the query string, exchanges the
    authorization code for credentials, and best-effort mirrors them into the
    OAuth 2.1 session store. Always returns an HTML page describing the result.
    """
    state = request.query_params.get("state")
    code = request.query_params.get("code")
    error = request.query_params.get("error")

    # Google reports consent-screen failures via an `error` query parameter.
    if error:
        msg = (
            f"Authentication failed: Google returned an error: {error}. State: {state}."
        )
        logger.error(msg)
        return create_error_response(msg)

    if not code:
        msg = "Authentication failed: No authorization code received from Google."
        logger.error(msg)
        return create_error_response(msg)

    try:
        # Configuration problems (missing client secrets) are server-side errors.
        error_message = check_client_secrets()
        if error_message:
            return create_server_error_response(error_message)

        logger.info("OAuth callback: Received authorization code.")

        # The MCP session id is optional; it is only present when middleware
        # attached one to the request state.
        mcp_session_id = None
        if hasattr(request, "state") and hasattr(request.state, "session_id"):
            mcp_session_id = request.state.session_id

        # Exchange the authorization code for user credentials.
        verified_user_id, credentials = handle_auth_callback(
            scopes=get_current_scopes(),
            authorization_response=str(request.url),
            redirect_uri=get_oauth_redirect_uri_for_current_mode(),
            session_id=mcp_session_id,
        )

        logger.info(
            f"OAuth callback: Successfully authenticated user: {verified_user_id}."
        )

        # Best-effort: mirror the credentials into the OAuth 2.1 session store.
        # Failures here are logged but do not fail the overall callback.
        try:
            store = get_oauth21_session_store()

            store.store_session(
                user_email=verified_user_id,
                access_token=credentials.token,
                refresh_token=credentials.refresh_token,
                token_uri=credentials.token_uri,
                client_id=credentials.client_id,
                client_secret=credentials.client_secret,
                scopes=credentials.scopes,
                expiry=credentials.expiry,
                session_id=f"google-{state}",
                mcp_session_id=mcp_session_id,
            )
            logger.info(
                f"Stored Google credentials in OAuth 2.1 session store for {verified_user_id}"
            )
        except Exception as e:
            logger.error(f"Failed to store credentials in OAuth 2.1 store: {e}")

        return create_success_response(verified_user_id)
    except Exception as e:
        logger.error(f"Error processing OAuth callback: {str(e)}", exc_info=True)
        return create_server_error_response(str(e))
|
||||
|
||||
|
||||
@server.tool()
async def start_google_auth(
    service_name: str, user_google_email: str = USER_GOOGLE_EMAIL
) -> str:
    """
    Manually initiate Google OAuth authentication flow.

    NOTE: This is a legacy OAuth 2.0 tool and is disabled when OAuth 2.1 is enabled.
    The authentication system automatically handles credential checks and prompts for
    authentication when needed. Only use this tool if:
    1. You need to re-authenticate with different credentials
    2. You want to proactively authenticate before using other tools
    3. The automatic authentication flow failed and you need to retry

    In most cases, simply try calling the Google Workspace tool you need - it will
    automatically handle authentication if required.
    """
    # Legacy flow is unavailable under OAuth 2.1; explain what to do instead.
    if is_oauth21_enabled():
        if is_external_oauth21_provider():
            hint = (
                "Provide a valid OAuth 2.1 bearer token in the Authorization header "
                "and retry the original tool."
            )
        else:
            hint = (
                "Authenticate through your MCP client's OAuth 2.1 flow and retry the "
                "original tool."
            )
        return "start_google_auth is disabled when OAuth 2.1 is enabled. " + hint

    if not user_google_email:
        raise ValueError("user_google_email must be provided.")

    error_message = check_client_secrets()
    if error_message:
        return f"**Authentication Error:** {error_message}"

    try:
        return await start_auth_flow(
            user_google_email=user_google_email,
            service_name=service_name,
            redirect_uri=get_oauth_redirect_uri_for_current_mode(),
        )
    except Exception as e:
        logger.error(f"Failed to start Google authentication flow: {e}", exc_info=True)
        return f"**Error:** An unexpected error occurred: {e}"
|
||||
211
core/tool_registry.py
Normal file
211
core/tool_registry.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Tool Registry for Conditional Tool Registration
|
||||
|
||||
This module provides a registry system that allows tools to be conditionally registered
|
||||
based on tier configuration, replacing direct @server.tool() decorators.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Set, Optional, Callable
|
||||
|
||||
from auth.oauth_config import is_oauth21_enabled
|
||||
from auth.permissions import is_permissions_mode, get_allowed_scopes_set
|
||||
from auth.scopes import is_read_only_mode, get_all_read_only_scopes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global registry of enabled tools
|
||||
_enabled_tools: Optional[Set[str]] = None
|
||||
|
||||
|
||||
def set_enabled_tools(tool_names: Optional[Set[str]]):
    """Install *tool_names* as the module-wide enabled set (None = all tools)."""
    global _enabled_tools
    _enabled_tools = tool_names
|
||||
|
||||
|
||||
def get_enabled_tools() -> Optional[Set[str]]:
    """Return the enabled-tool set; None signals that every tool is enabled."""
    return _enabled_tools
|
||||
|
||||
|
||||
def is_tool_enabled(tool_name: str) -> bool:
    """Return True when *tool_name* should be registered.

    A registry of None means filtering is inactive, so everything is enabled.
    """
    return _enabled_tools is None or tool_name in _enabled_tools
|
||||
|
||||
|
||||
def conditional_tool(server, tool_name: str):
    """Build a decorator that registers *tool_name* only when it is enabled.

    Args:
        server: The FastMCP server instance.
        tool_name: The name of the tool to register.

    Returns:
        A decorator that registers the function as a tool when enabled, or
        returns the function untouched when the tool is disabled.
    """

    def decorator(func: Callable) -> Callable:
        if not is_tool_enabled(tool_name):
            logger.debug(f"Skipping tool registration: {tool_name}")
            return func
        logger.debug(f"Registering tool: {tool_name}")
        return server.tool()(func)

    return decorator
|
||||
|
||||
|
||||
def wrap_server_tool_method(server):
    """Monkey-patch ``server.tool`` so every tool registration is recorded.

    Registered tool names accumulate in ``server._tracked_tools`` for later
    post-registration filtering; the registration itself is unchanged.
    """
    original_tool = server.tool
    server._tracked_tools = []

    def tracking_tool(*args, **kwargs):
        inner_decorator = original_tool(*args, **kwargs)

        def recorder(func: Callable) -> Callable:
            server._tracked_tools.append(func.__name__)
            # Delegate to the real decorator so the tool is still registered.
            return inner_decorator(func)

        return recorder

    server.tool = tracking_tool
|
||||
|
||||
|
||||
def get_tool_components(server) -> dict:
    """Return a tool_name -> tool_object mapping from the server's local provider.

    Note: reads ``local_provider._components`` directly because the public
    ``list_tools()`` is async-only, while callers (startup filtering, CLI)
    run synchronously.
    """
    provider = getattr(server, "local_provider", None)
    if provider is None:
        return {}

    result = {}
    for key, component in getattr(provider, "_components", {}).items():
        if not key.startswith("tool:"):
            continue
        # Keys look like "tool:name@version"; strip the prefix and version.
        tool_name = key.split(":", 1)[1].rsplit("@", 1)[0]
        result[tool_name] = component
    return result
|
||||
|
||||
|
||||
def filter_server_tools(server):
    """Remove disabled tools from the server after registration.

    Applies up to four filters — tier allowlist, OAuth 2.1, read-only mode,
    and granular permissions — collects the union of disabled tool names, then
    removes them via ``server.local_provider.remove_tool``. Removal failures
    are logged and skipped rather than raised.
    """
    enabled_tools = get_enabled_tools()
    oauth21_enabled = is_oauth21_enabled()
    permissions_mode = is_permissions_mode()
    # Fast path: nothing restricts the tool set, so there is nothing to do.
    if (
        enabled_tools is None
        and not oauth21_enabled
        and not is_read_only_mode()
        and not permissions_mode
    ):
        return

    tools_removed = 0
    tool_components = get_tool_components(server)

    read_only_mode = is_read_only_mode()
    allowed_scopes = set(get_all_read_only_scopes()) if read_only_mode else None

    tools_to_remove = set()

    # 1. Tier filtering
    if enabled_tools is not None:
        for tool_name in tool_components:
            if not is_tool_enabled(tool_name):
                tools_to_remove.add(tool_name)

    # 2. OAuth 2.1 filtering
    if oauth21_enabled and "start_google_auth" in tool_components:
        tools_to_remove.add("start_google_auth")
        logger.info("OAuth 2.1 enabled: disabling start_google_auth tool")

    # 3. Read-only mode filtering (skipped when granular permissions are active)
    if read_only_mode and not permissions_mode:
        for tool_name, tool_obj in tool_components.items():
            if tool_name in tools_to_remove:
                continue

            # Check if tool has required scopes attached (from @require_google_service)
            func_to_check = tool_obj
            if hasattr(tool_obj, "fn"):
                func_to_check = tool_obj.fn

            required_scopes = getattr(func_to_check, "_required_google_scopes", [])

            if required_scopes:
                # If ANY required scope is not in the allowed read-only scopes, disable the tool
                if not all(scope in allowed_scopes for scope in required_scopes):
                    logger.info(
                        f"Read-only mode: Disabling tool '{tool_name}' (requires write scopes: {required_scopes})"
                    )
                    tools_to_remove.add(tool_name)

    # 4. Granular permissions filtering
    # No scope hierarchy expansion here — permission levels are already cumulative
    # and explicitly define allowed scopes. Hierarchy expansion would defeat the
    # purpose (e.g. gmail.modify in the hierarchy covers gmail.send, but the
    # "organize" permission level intentionally excludes gmail.send).
    if permissions_mode:
        perm_allowed = get_allowed_scopes_set() or set()

        for tool_name, tool_obj in tool_components.items():
            if tool_name in tools_to_remove:
                continue

            # Scopes may live on the wrapped function rather than the tool object.
            func_to_check = tool_obj
            if hasattr(tool_obj, "fn"):
                func_to_check = tool_obj.fn

            required_scopes = getattr(func_to_check, "_required_google_scopes", [])
            if required_scopes:
                if not all(scope in perm_allowed for scope in required_scopes):
                    logger.info(
                        "Permissions mode: Disabling tool '%s' (requires: %s)",
                        tool_name,
                        required_scopes,
                    )
                    tools_to_remove.add(tool_name)

    # Perform the removals; a failure on one tool never blocks the others.
    for tool_name in tools_to_remove:
        try:
            server.local_provider.remove_tool(tool_name)
        except AttributeError:
            logger.warning(
                "Failed to remove tool '%s': remove_tool not available on server.local_provider",
                tool_name,
            )
            continue
        except Exception as exc:
            logger.warning(
                "Failed to remove tool '%s': %s",
                tool_name,
                exc,
            )
            continue
        tools_removed += 1

    if tools_removed > 0:
        enabled_count = len(enabled_tools) if enabled_tools is not None else "all"
        if permissions_mode:
            mode = "Permissions"
        elif is_read_only_mode():
            mode = "Read-Only"
        else:
            mode = "Full"
        logger.info(
            f"Tool filtering: removed {tools_removed} tools, {enabled_count} enabled. Mode: {mode}"
        )
|
||||
196
core/tool_tier_loader.py
Normal file
196
core/tool_tier_loader.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Tool Tier Loader Module
|
||||
|
||||
This module provides functionality to load and resolve tool tiers from the YAML configuration.
|
||||
It integrates with the existing tool enablement workflow to support tiered tool loading.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Set, Literal, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TierLevel = Literal["core", "extended", "complete"]
|
||||
|
||||
|
||||
class ToolTierLoader:
    """Loads and manages tool tiers from the YAML configuration.

    The configuration maps service name -> {tier name -> [tool names]} and is
    loaded lazily on first access, then cached on the instance.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the tool tier loader.

        Args:
            config_path: Path to the tool_tiers.yaml file. If None, uses the
                default location (core/tool_tiers.yaml next to this module).
        """
        if config_path is None:
            # Default to core/tool_tiers.yaml relative to this file
            config_path = Path(__file__).parent / "tool_tiers.yaml"

        self.config_path = Path(config_path)
        # Cached parsed configuration; populated lazily by _load_config().
        self._tiers_config: Optional[Dict] = None

    def _load_config(self) -> Dict:
        """Load (and cache) the tool tiers configuration from the YAML file.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file contains invalid YAML.
            RuntimeError: For any other loading failure.
        """
        if self._tiers_config is not None:
            return self._tiers_config

        if not self.config_path.exists():
            raise FileNotFoundError(
                f"Tool tiers configuration not found: {self.config_path}"
            )

        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                self._tiers_config = yaml.safe_load(f)
            logger.info(f"Loaded tool tiers configuration from {self.config_path}")
            return self._tiers_config
        except yaml.YAMLError as e:
            # Chain the original parse error so the root cause stays visible.
            raise ValueError(f"Invalid YAML in tool tiers configuration: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Failed to load tool tiers configuration: {e}") from e

    def get_available_services(self) -> List[str]:
        """Get list of all available services defined in the configuration."""
        config = self._load_config()
        return list(config.keys())

    def get_tools_for_tier(
        self, tier: TierLevel, services: Optional[List[str]] = None
    ) -> List[str]:
        """
        Get all tools for a specific tier level (this tier only, not cumulative).

        Args:
            tier: The tier level (core, extended, complete)
            services: Optional list of services to filter by. If None, includes all services.

        Returns:
            List of tool names for the specified tier level
        """
        config = self._load_config()
        tools = []

        # If no services specified, use all available services
        if services is None:
            services = self.get_available_services()

        for service in services:
            if service not in config:
                logger.warning(
                    f"Service '{service}' not found in tool tiers configuration"
                )
                continue

            service_config = config[service]
            if tier not in service_config:
                logger.debug(f"Tier '{tier}' not defined for service '{service}'")
                continue

            tier_tools = service_config[tier]
            if tier_tools:  # Handle empty lists
                tools.extend(tier_tools)

        return tools

    def get_tools_up_to_tier(
        self, tier: TierLevel, services: Optional[List[str]] = None
    ) -> List[str]:
        """
        Get all tools up to and including the specified tier level.

        Args:
            tier: The maximum tier level to include
            services: Optional list of services to filter by. If None, includes all services.

        Returns:
            List of unique tool names up to the specified tier level, in first-seen order
        """
        tier_order = ["core", "extended", "complete"]
        max_tier_index = tier_order.index(tier)

        tools = []
        for i in range(max_tier_index + 1):
            tools.extend(self.get_tools_for_tier(tier_order[i], services))

        # dict.fromkeys removes duplicates while preserving first-seen order.
        return list(dict.fromkeys(tools))

    def get_services_for_tools(self, tool_names: List[str]) -> Set[str]:
        """
        Get the service names that provide the specified tools.

        Args:
            tool_names: List of tool names to lookup

        Returns:
            Set of service names that provide any of the specified tools
        """
        config = self._load_config()
        services = set()

        for service, service_config in config.items():
            # Tier names are irrelevant here; only the tool lists matter.
            for tier_tools in service_config.values():
                if tier_tools and any(tool in tier_tools for tool in tool_names):
                    services.add(service)
                    break

        return services
|
||||
|
||||
|
||||
def get_tools_for_tier(
    tier: TierLevel, services: Optional[List[str]] = None
) -> List[str]:
    """
    Convenience function returning all tools up to AND INCLUDING *tier*.

    Note: despite the name, this is cumulative — it delegates to
    ``ToolTierLoader.get_tools_up_to_tier``, not the per-tier
    ``ToolTierLoader.get_tools_for_tier`` method.

    Args:
        tier: The maximum tier level to include (core, extended, complete)
        services: Optional list of services to filter by

    Returns:
        List of unique tool names for all tiers up to the specified level
    """
    loader = ToolTierLoader()
    return loader.get_tools_up_to_tier(tier, services)
|
||||
|
||||
|
||||
def resolve_tools_from_tier(
    tier: TierLevel, services: Optional[List[str]] = None
) -> tuple[List[str], List[str]]:
    """
    Resolve tool names and the services providing them for the given tier.

    Args:
        tier: The tier level (core, extended, complete)
        services: Optional list of services to filter by

    Returns:
        Tuple of (tool_names, service_names) where:
            - tool_names: specific tool names up to and including the tier
            - service_names: sorted service names that should be imported
    """
    loader = ToolTierLoader()

    # Cumulative tool list for the tier, then map back to owning services.
    tool_names = loader.get_tools_up_to_tier(tier, services)
    owning_services = sorted(loader.get_services_for_tools(tool_names))

    logger.info(
        f"Tier '{tier}' resolved to {len(tool_names)} tools across {len(owning_services)} services: {owning_services}"
    )

    return tool_names, owning_services
|
||||
172
core/tool_tiers.yaml
Normal file
172
core/tool_tiers.yaml
Normal file
@@ -0,0 +1,172 @@
|
||||
gmail:
|
||||
core:
|
||||
- search_gmail_messages
|
||||
- get_gmail_message_content
|
||||
- get_gmail_messages_content_batch
|
||||
- send_gmail_message
|
||||
|
||||
extended:
|
||||
- get_gmail_attachment_content
|
||||
- get_gmail_thread_content
|
||||
- modify_gmail_message_labels
|
||||
- list_gmail_labels
|
||||
- manage_gmail_label
|
||||
- draft_gmail_message
|
||||
- list_gmail_filters
|
||||
- manage_gmail_filter
|
||||
|
||||
complete:
|
||||
- get_gmail_threads_content_batch
|
||||
- batch_modify_gmail_message_labels
|
||||
- start_google_auth
|
||||
|
||||
drive:
|
||||
core:
|
||||
- search_drive_files
|
||||
- get_drive_file_content
|
||||
- get_drive_file_download_url
|
||||
- create_drive_file
|
||||
- create_drive_folder
|
||||
- import_to_google_doc
|
||||
- get_drive_shareable_link
|
||||
extended:
|
||||
- list_drive_items
|
||||
- copy_drive_file
|
||||
- update_drive_file
|
||||
- manage_drive_access
|
||||
- set_drive_file_permissions
|
||||
complete:
|
||||
- get_drive_file_permissions
|
||||
- check_drive_file_public_access
|
||||
|
||||
calendar:
|
||||
core:
|
||||
- list_calendars
|
||||
- get_events
|
||||
- manage_event
|
||||
extended:
|
||||
- query_freebusy
|
||||
complete: []
|
||||
|
||||
docs:
|
||||
core:
|
||||
- get_doc_content
|
||||
- create_doc
|
||||
- modify_doc_text
|
||||
extended:
|
||||
- export_doc_to_pdf
|
||||
- search_docs
|
||||
- find_and_replace_doc
|
||||
- list_docs_in_folder
|
||||
- insert_doc_elements
|
||||
- update_paragraph_style
|
||||
- get_doc_as_markdown
|
||||
complete:
|
||||
- insert_doc_image
|
||||
- update_doc_headers_footers
|
||||
- batch_update_doc
|
||||
- inspect_doc_structure
|
||||
- create_table_with_data
|
||||
- debug_table_structure
|
||||
- list_document_comments
|
||||
- manage_document_comment
|
||||
|
||||
sheets:
|
||||
core:
|
||||
- create_spreadsheet
|
||||
- read_sheet_values
|
||||
- modify_sheet_values
|
||||
extended:
|
||||
- list_spreadsheets
|
||||
- get_spreadsheet_info
|
||||
- format_sheet_range
|
||||
complete:
|
||||
- create_sheet
|
||||
- list_spreadsheet_comments
|
||||
- manage_spreadsheet_comment
|
||||
- manage_conditional_formatting
|
||||
|
||||
chat:
|
||||
core:
|
||||
- send_message
|
||||
- get_messages
|
||||
- search_messages
|
||||
- create_reaction
|
||||
extended:
|
||||
- list_spaces
|
||||
- download_chat_attachment
|
||||
complete: []
|
||||
|
||||
forms:
|
||||
core:
|
||||
- create_form
|
||||
- get_form
|
||||
extended:
|
||||
- list_form_responses
|
||||
complete:
|
||||
- set_publish_settings
|
||||
- get_form_response
|
||||
- batch_update_form
|
||||
|
||||
slides:
|
||||
core:
|
||||
- create_presentation
|
||||
- get_presentation
|
||||
extended:
|
||||
- batch_update_presentation
|
||||
- get_page
|
||||
- get_page_thumbnail
|
||||
complete:
|
||||
- list_presentation_comments
|
||||
- manage_presentation_comment
|
||||
|
||||
tasks:
|
||||
core:
|
||||
- get_task
|
||||
- list_tasks
|
||||
- manage_task
|
||||
extended: []
|
||||
complete:
|
||||
- list_task_lists
|
||||
- get_task_list
|
||||
- manage_task_list
|
||||
|
||||
contacts:
|
||||
core:
|
||||
- search_contacts
|
||||
- get_contact
|
||||
- list_contacts
|
||||
- manage_contact
|
||||
extended:
|
||||
- list_contact_groups
|
||||
- get_contact_group
|
||||
complete:
|
||||
- manage_contacts_batch
|
||||
- manage_contact_group
|
||||
|
||||
search:
|
||||
core:
|
||||
- search_custom
|
||||
extended: []
|
||||
complete:
|
||||
- get_search_engine_info
|
||||
|
||||
appscript:
|
||||
core:
|
||||
- list_script_projects
|
||||
- get_script_project
|
||||
- get_script_content
|
||||
- create_script_project
|
||||
- update_script_content
|
||||
- run_script_function
|
||||
- generate_trigger_code
|
||||
extended:
|
||||
- manage_deployment
|
||||
- list_deployments
|
||||
- delete_script_project
|
||||
- list_versions
|
||||
- create_version
|
||||
- get_version
|
||||
- list_script_processes
|
||||
- get_script_metrics
|
||||
complete: []
|
||||
493
core/utils.py
Normal file
493
core/utils.py
Normal file
@@ -0,0 +1,493 @@
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import zipfile
|
||||
import ssl
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from defusedxml import ElementTree as ET
|
||||
|
||||
from googleapiclient.errors import HttpError
|
||||
from .api_enablement import get_api_enablement_message
|
||||
from auth.google_auth import GoogleAuthenticationError
|
||||
from auth.oauth_config import is_oauth21_enabled, is_external_oauth21_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransientNetworkError(Exception):
    """Raised when a network operation still fails after exhausting retries."""
|
||||
|
||||
|
||||
class UserInputError(Exception):
    """Signals a user-facing input/validation error that should not be retried."""
|
||||
|
||||
|
||||
# Directories from which local file reads are allowed.
|
||||
# The user's home directory is the default safe base.
|
||||
# Override via ALLOWED_FILE_DIRS env var (os.pathsep-separated paths).
|
||||
_ALLOWED_FILE_DIRS_ENV = "ALLOWED_FILE_DIRS"
|
||||
|
||||
|
||||
def _get_allowed_file_dirs() -> list[Path]:
    """Return the list of directories from which local file access is permitted."""
    configured = os.environ.get(_ALLOWED_FILE_DIRS_ENV)
    if configured:
        dirs = []
        for raw in configured.split(os.pathsep):
            if raw.strip():
                dirs.append(Path(raw).expanduser().resolve())
        return dirs
    # No override configured: fall back to the user's home directory.
    home = Path.home()
    return [home] if home else []
|
||||
|
||||
|
||||
def validate_file_path(file_path: str) -> Path:
    """
    Validate that a file path is safe to read from the server filesystem.

    Resolves the path canonically (following symlinks), then verifies it falls
    within one of the allowed base directories. Rejects paths to sensitive
    system locations regardless of allowlist.

    Args:
        file_path: The raw file path string to validate.

    Returns:
        Path: The resolved, validated Path object.

    Raises:
        FileNotFoundError: If the resolved path does not exist.
        ValueError: If the path is outside allowed directories or targets
            a sensitive location.
    """
    resolved = Path(file_path).resolve()

    if not resolved.exists():
        raise FileNotFoundError(f"Path does not exist: {resolved}")

    # Block sensitive file patterns regardless of allowlist
    resolved_str = str(resolved)
    file_name = resolved.name.lower()

    # Block .env files and variants (.env, .env.local, .env.production, etc.)
    if file_name == ".env" or file_name.startswith(".env."):
        raise ValueError(
            f"Access to '{resolved_str}' is not allowed: "
            ".env files may contain secrets and cannot be read, uploaded, or attached."
        )

    # Block well-known sensitive system paths (including macOS /private variants)
    sensitive_prefixes = (
        "/proc",
        "/sys",
        "/dev",
        "/etc/shadow",
        "/etc/passwd",
        "/private/etc/shadow",
        "/private/etc/passwd",
    )
    for prefix in sensitive_prefixes:
        if resolved_str == prefix or resolved_str.startswith(prefix + "/"):
            raise ValueError(
                f"Access to '{resolved_str}' is not allowed: "
                "path is in a restricted system location."
            )

    # Block sensitive directories that commonly contain credentials/keys.
    # Path.home() is loop-invariant, so compute it once.
    home = Path.home()
    sensitive_dirs = (
        ".ssh",
        ".aws",
        ".kube",
        ".gnupg",
        ".config/gcloud",
    )
    for sensitive_dir in sensitive_dirs:
        blocked = home / sensitive_dir
        if resolved == blocked or str(resolved).startswith(str(blocked) + "/"):
            raise ValueError(
                f"Access to '{resolved_str}' is not allowed: "
                "path is in a directory that commonly contains secrets or credentials."
            )

    # Block other credential/secret files, matched by basename only.
    sensitive_names = {
        ".credentials",
        ".credentials.json",
        "credentials.json",
        "client_secret.json",
        "client_secrets.json",
        "service_account.json",
        "service-account.json",
        ".npmrc",
        ".pypirc",
        ".netrc",
        ".git-credentials",
    }
    if file_name in sensitive_names:
        raise ValueError(
            f"Access to '{resolved_str}' is not allowed: "
            "this file commonly contains secrets or credentials."
        )

    # BUG FIX: ".docker/config.json" previously lived in the basename set above,
    # where it could never match (a basename contains no "/"). The Docker CLI
    # credential file is checked by its full path instead.
    if resolved == home / ".docker" / "config.json":
        raise ValueError(
            f"Access to '{resolved_str}' is not allowed: "
            "this file commonly contains secrets or credentials."
        )

    allowed_dirs = _get_allowed_file_dirs()
    if not allowed_dirs:
        raise ValueError(
            "No allowed file directories configured. "
            "Set the ALLOWED_FILE_DIRS environment variable or ensure a home directory exists."
        )

    # Accept the path if it sits under any allowed base directory.
    for allowed in allowed_dirs:
        if resolved.is_relative_to(allowed):
            return resolved

    raise ValueError(
        f"Access to '{resolved_str}' is not allowed: "
        f"path is outside permitted directories ({', '.join(str(d) for d in allowed_dirs)}). "
        "Set ALLOWED_FILE_DIRS to adjust."
    )
|
||||
|
||||
|
||||
def check_credentials_directory_permissions(credentials_dir: Optional[str] = None) -> None:
    """
    Check if the service has appropriate permissions to create and write to the .credentials directory.

    Writability is verified by creating and deleting a small probe file. If the
    directory does not exist it is created (including parents); a directory we
    just created but cannot write to is rolled back with ``os.rmdir``.

    Args:
        credentials_dir: Path to the credentials directory (default: uses get_default_credentials_dir())

    Raises:
        PermissionError: If the service lacks necessary permissions
        OSError: If there are other file system issues
    """
    if credentials_dir is None:
        # NOTE(review): imported lazily — presumably to avoid a circular import
        # with auth.google_auth at module load time; confirm.
        from auth.google_auth import get_default_credentials_dir

        credentials_dir = get_default_credentials_dir()

    try:
        # Check if directory exists
        if os.path.exists(credentials_dir):
            # Directory exists, check if we can write to it with a probe file
            test_file = os.path.join(credentials_dir, ".permission_test")
            try:
                with open(test_file, "w") as f:
                    f.write("test")
                os.remove(test_file)
                logger.info(
                    f"Credentials directory permissions check passed: {os.path.abspath(credentials_dir)}"
                )
            except (PermissionError, OSError) as e:
                raise PermissionError(
                    f"Cannot write to existing credentials directory '{os.path.abspath(credentials_dir)}': {e}"
                )
        else:
            # Directory doesn't exist, try to create it and its parent directories
            try:
                os.makedirs(credentials_dir, exist_ok=True)
                # Test writing to the new directory
                test_file = os.path.join(credentials_dir, ".permission_test")
                with open(test_file, "w") as f:
                    f.write("test")
                os.remove(test_file)
                logger.info(
                    f"Created credentials directory with proper permissions: {os.path.abspath(credentials_dir)}"
                )
            except (PermissionError, OSError) as e:
                # Clean up if we created the directory but can't write to it
                try:
                    if os.path.exists(credentials_dir):
                        os.rmdir(credentials_dir)
                except (PermissionError, OSError):
                    # Best-effort rollback; ignore failures here.
                    pass
                raise PermissionError(
                    f"Cannot create or write to credentials directory '{os.path.abspath(credentials_dir)}': {e}"
                )

    except PermissionError:
        # Our own PermissionErrors propagate unchanged.
        raise
    except Exception as e:
        # Anything unexpected is surfaced as an OSError for the caller.
        raise OSError(
            f"Unexpected error checking credentials directory permissions: {e}"
        )
|
||||
|
||||
|
||||
def extract_office_xml_text(file_bytes: bytes, mime_type: str) -> Optional[str]:
    """
    Very light-weight XML scraper for Word, Excel, PowerPoint files.
    Returns plain-text if something readable is found, else None.
    Uses zipfile + defusedxml.ElementTree.

    Args:
        file_bytes: Raw bytes of the Office file (a ZIP container).
        mime_type: MIME type used to select which inner XML members to read;
            unrecognized types return None.

    Returns:
        Extracted plain text (members joined by blank lines), or None when the
        file is not a readable Office archive or contains no text.
    """
    # Excel cells of type "s" store an index into sharedStrings.xml; collect
    # those strings up front so cell values can be resolved later.
    shared_strings: List[str] = []
    ns_excel_main = "http://schemas.openxmlformats.org/spreadsheetml/2006/main"

    try:
        with zipfile.ZipFile(io.BytesIO(file_bytes)) as zf:
            targets: List[str] = []
            # Map MIME → iterable of XML files to inspect
            if (
                mime_type
                == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
            ):
                # Word keeps the whole document body in a single member.
                targets = ["word/document.xml"]
            elif (
                mime_type
                == "application/vnd.openxmlformats-officedocument.presentationml.presentation"
            ):
                # One XML member per slide.
                targets = [n for n in zf.namelist() if n.startswith("ppt/slides/slide")]
            elif (
                mime_type
                == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
            ):
                # One XML member per worksheet; drawing parts carry no cell text.
                targets = [
                    n
                    for n in zf.namelist()
                    if n.startswith("xl/worksheets/sheet") and "drawing" not in n
                ]
                # Attempt to parse sharedStrings.xml for Excel files
                try:
                    shared_strings_xml = zf.read("xl/sharedStrings.xml")
                    shared_strings_root = ET.fromstring(shared_strings_xml)
                    for si_element in shared_strings_root.findall(
                        f"{{{ns_excel_main}}}si"
                    ):
                        text_parts = []
                        # Find all <t> elements, simple or within <r> runs, and concatenate their text
                        for t_element in si_element.findall(f".//{{{ns_excel_main}}}t"):
                            if t_element.text:
                                text_parts.append(t_element.text)
                        shared_strings.append("".join(text_parts))
                except KeyError:
                    # zf.read raises KeyError when the member is absent.
                    logger.info(
                        "No sharedStrings.xml found in Excel file (this is optional)."
                    )
                except ET.ParseError as e:
                    logger.error(f"Error parsing sharedStrings.xml: {e}")
                except (
                    Exception
                ) as e:  # Catch any other unexpected error during sharedStrings parsing
                    logger.error(
                        f"Unexpected error processing sharedStrings.xml: {e}",
                        exc_info=True,
                    )
            else:
                # Unsupported MIME type: nothing we know how to scrape.
                return None

            pieces: List[str] = []
            for member in targets:
                try:
                    xml_content = zf.read(member)
                    xml_root = ET.fromstring(xml_content)
                    member_texts: List[str] = []

                    if (
                        mime_type
                        == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
                    ):
                        for cell_element in xml_root.findall(
                            f".//{{{ns_excel_main}}}c"
                        ):  # Find all <c> elements
                            value_element = cell_element.find(
                                f"{{{ns_excel_main}}}v"
                            )  # Find <v> under <c>

                            # Skip if cell has no value element or value element has no text
                            if value_element is None or value_element.text is None:
                                continue

                            cell_type = cell_element.get("t")
                            if cell_type == "s":  # Shared string
                                try:
                                    ss_idx = int(value_element.text)
                                    if 0 <= ss_idx < len(shared_strings):
                                        member_texts.append(shared_strings[ss_idx])
                                    else:
                                        # Out-of-range index: log and keep going rather than abort the sheet.
                                        logger.warning(
                                            f"Invalid shared string index {ss_idx} in {member}. Max index: {len(shared_strings) - 1}"
                                        )
                                except ValueError:
                                    logger.warning(
                                        f"Non-integer shared string index: '{value_element.text}' in {member}."
                                    )
                            else:  # Direct value (number, boolean, inline string if not 's')
                                member_texts.append(value_element.text)
                    else:  # Word or PowerPoint
                        for elem in xml_root.iter():
                            # For Word: <w:t> where w is "http://schemas.openxmlformats.org/wordprocessingml/2006/main"
                            # For PowerPoint: <a:t> where a is "http://schemas.openxmlformats.org/drawingml/2006/main"
                            if (
                                elem.tag.endswith("}t") and elem.text
                            ):  # Check for any namespaced tag ending with 't'
                                cleaned_text = elem.text.strip()
                                if (
                                    cleaned_text
                                ):  # Add only if there's non-whitespace text
                                    member_texts.append(cleaned_text)

                    if member_texts:
                        pieces.append(
                            " ".join(member_texts)
                        )  # Join texts from one member with spaces

                except ET.ParseError as e:
                    logger.warning(
                        f"Could not parse XML in member '{member}' for {mime_type} file: {e}"
                    )
                except Exception as e:
                    logger.error(
                        f"Error processing member '{member}' for {mime_type}: {e}",
                        exc_info=True,
                    )
                    # continue processing other members

            if not pieces:  # If no text was extracted at all
                return None

            # Join content from different members (sheets/slides) with double newlines for separation
            text = "\n\n".join(pieces).strip()
            return text or None  # Ensure None is returned if text is empty after strip

    except zipfile.BadZipFile:
        logger.warning(f"File is not a valid ZIP archive (mime_type: {mime_type}).")
        return None
    except (
        ET.ParseError
    ) as e:  # Catch parsing errors at the top level if zipfile itself is XML-like
        logger.error(f"XML parsing error at a high level for {mime_type}: {e}")
        return None
    except Exception as e:
        # Never let extraction failures propagate; callers treat None as "no text".
        logger.error(
            f"Failed to extract office XML text for {mime_type}: {e}", exc_info=True
        )
        return None
|
||||
|
||||
|
||||
def handle_http_errors(
    tool_name: str, is_read_only: bool = False, service_type: Optional[str] = None
):
    """
    A decorator to handle Google API HttpErrors and transient SSL errors in a standardized way.

    It wraps a tool function, catches HttpError, logs a detailed error message,
    and raises a generic Exception with a user-friendly message.

    If is_read_only is True, it will also catch ssl.SSLError and retry with
    exponential backoff. After exhausting retries, it raises a TransientNetworkError.

    Args:
        tool_name (str): The name of the tool being decorated (e.g., 'list_calendars').
        is_read_only (bool): If True, the operation is considered safe to retry on
            transient network errors. Defaults to False.
        service_type (str): Optional. The Google service type (e.g., 'calendar', 'gmail').

    Returns:
        A decorator that wraps an async tool function.
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Up to 3 attempts with exponential backoff (1s, 2s, 4s) — retries
            # apply only to SSL errors on read-only operations.
            max_retries = 3
            base_delay = 1

            for attempt in range(max_retries):
                try:
                    return await func(*args, **kwargs)
                except ssl.SSLError as e:
                    if is_read_only and attempt < max_retries - 1:
                        delay = base_delay * (2**attempt)
                        logger.warning(
                            f"SSL error in {tool_name} on attempt {attempt + 1}: {e}. Retrying in {delay} seconds..."
                        )
                        await asyncio.sleep(delay)
                    else:
                        # NOTE(review): when is_read_only is False this raises on the
                        # first failure, yet the message still says "after 3 attempts"
                        # — confirm whether that wording is intentional.
                        logger.error(
                            f"SSL error in {tool_name} on final attempt: {e}. Raising exception."
                        )
                        raise TransientNetworkError(
                            f"A transient SSL error occurred in '{tool_name}' after {max_retries} attempts. "
                            "This is likely a temporary network or certificate issue. Please try again shortly."
                        ) from e
                except UserInputError as e:
                    # Caller-supplied bad input: log at warning level and propagate as-is.
                    message = f"Input error in {tool_name}: {e}"
                    logger.warning(message)
                    raise e
                except HttpError as error:
                    user_google_email = kwargs.get("user_google_email", "N/A")
                    error_details = str(error)

                    # Check if this is an API not enabled error
                    if (
                        error.resp.status == 403
                        and "accessNotConfigured" in error_details
                    ):
                        # Try to build a direct "enable this API" link/message.
                        enablement_msg = get_api_enablement_message(
                            error_details, service_type
                        )

                        if enablement_msg:
                            message = (
                                f"API error in {tool_name}: {enablement_msg}\n\n"
                                f"User: {user_google_email}"
                            )
                        else:
                            message = (
                                f"API error in {tool_name}: {error}. "
                                f"The required API is not enabled for your project. "
                                f"Please check the Google Cloud Console to enable it."
                            )
                    elif error.resp.status in [401, 403]:
                        # Authentication/authorization errors
                        # Pick the re-auth hint matching the server's configured auth mode.
                        if is_oauth21_enabled():
                            if is_external_oauth21_provider():
                                auth_hint = (
                                    "LLM: Ask the user to provide a valid OAuth 2.1 "
                                    "bearer token in the Authorization header and retry."
                                )
                            else:
                                auth_hint = (
                                    "LLM: Ask the user to authenticate via their MCP "
                                    "client's OAuth 2.1 flow and retry."
                                )
                        else:
                            auth_hint = (
                                "LLM: Try 'start_google_auth' with the user's email "
                                "and the appropriate service_name."
                            )
                        message = (
                            f"API error in {tool_name}: {error}. "
                            f"You might need to re-authenticate for user '{user_google_email}'. "
                            f"{auth_hint}"
                        )
                    else:
                        # Other HTTP errors (400 Bad Request, etc.) - don't suggest re-auth
                        message = f"API error in {tool_name}: {error}"

                    logger.error(f"API error in {tool_name}: {error}", exc_info=True)
                    raise Exception(message) from error
                except TransientNetworkError:
                    # Re-raise without wrapping to preserve the specific error type
                    raise
                except GoogleAuthenticationError:
                    # Re-raise authentication errors without wrapping
                    raise
                except Exception as e:
                    # Last-resort catch-all: log with traceback and wrap in a generic message.
                    message = f"An unexpected error occurred in {tool_name}: {e}"
                    logger.exception(message)
                    raise Exception(message) from e

        # Propagate _required_google_scopes if present (for tool filtering)
        if hasattr(func, "_required_google_scopes"):
            wrapper._required_google_scopes = func._required_google_scopes

        return wrapper

    return decorator
|
||||
Reference in New Issue
Block a user