apply ruff formatting

This commit is contained in:
Taylor Wilsdon
2025-12-13 13:49:28 -08:00
parent 1d80a24ca4
commit 6b8352a354
50 changed files with 4010 additions and 2842 deletions

View File

@@ -45,15 +45,17 @@ INTERNAL_SERVICE_TO_API: Dict[str, str] = {
}
def extract_api_info_from_error(error_details: str) -> Tuple[Optional[str], Optional[str]]:
def extract_api_info_from_error(
error_details: str,
) -> Tuple[Optional[str], Optional[str]]:
"""
Extract API service and project ID from error details.
Returns:
Tuple of (api_service, project_id) or (None, None) if not found
"""
api_pattern = r'https://console\.developers\.google\.com/apis/api/([^/]+)/overview'
project_pattern = r'project[=\s]+([a-zA-Z0-9-]+)'
api_pattern = r"https://console\.developers\.google\.com/apis/api/([^/]+)/overview"
project_pattern = r"project[=\s]+([a-zA-Z0-9-]+)"
api_match = re.search(api_pattern, error_details)
project_match = re.search(project_pattern, error_details)
@@ -64,7 +66,9 @@ def extract_api_info_from_error(error_details: str) -> Tuple[Optional[str], Opti
return api_service, project_id
def get_api_enablement_message(error_details: str, service_type: Optional[str] = None) -> str:
def get_api_enablement_message(
error_details: str, service_type: Optional[str] = None
) -> str:
"""
Generate a helpful error message with direct API enablement link.
@@ -88,7 +92,7 @@ def get_api_enablement_message(error_details: str, service_type: Optional[str] =
enable_link = API_ENABLEMENT_LINKS[api_service]
service_display_name = next(
(name for name, api in SERVICE_NAME_TO_API.items() if api == api_service),
api_service
api_service,
)
message = (
@@ -101,4 +105,4 @@ def get_api_enablement_message(error_details: str, service_type: Optional[str] =
return message
return ""
return ""

View File

@@ -77,7 +77,9 @@ class AttachmentStorage:
file_path = STORAGE_DIR / f"{file_id}{extension}"
try:
file_path.write_bytes(file_bytes)
logger.info(f"Saved attachment {file_id} ({len(file_bytes)} bytes) to {file_path}")
logger.info(
f"Saved attachment {file_id} ({len(file_bytes)} bytes) to {file_path}"
)
except Exception as e:
logger.error(f"Failed to save attachment to {file_path}: {e}")
raise
@@ -213,4 +215,3 @@ def get_attachment_url(file_id: str) -> str:
base_url = f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}"
return f"{base_url}/attachments/{file_id}"

View File

@@ -36,79 +36,136 @@ def create_comment_tools(app_name: str, file_id_param: str):
# Create functions without decorators first, then apply decorators with proper names
if file_id_param == "document_id":
@require_google_service("drive", "drive_read")
@handle_http_errors(read_func_name, service_type="drive")
async def read_comments(service, user_google_email: str, document_id: str) -> str:
async def read_comments(
service, user_google_email: str, document_id: str
) -> str:
"""Read all comments from a Google Document."""
return await _read_comments_impl(service, app_name, document_id)
@require_google_service("drive", "drive_file")
@handle_http_errors(create_func_name, service_type="drive")
async def create_comment(service, user_google_email: str, document_id: str, comment_content: str) -> str:
async def create_comment(
service, user_google_email: str, document_id: str, comment_content: str
) -> str:
"""Create a new comment on a Google Document."""
return await _create_comment_impl(service, app_name, document_id, comment_content)
return await _create_comment_impl(
service, app_name, document_id, comment_content
)
@require_google_service("drive", "drive_file")
@handle_http_errors(reply_func_name, service_type="drive")
async def reply_to_comment(service, user_google_email: str, document_id: str, comment_id: str, reply_content: str) -> str:
async def reply_to_comment(
service,
user_google_email: str,
document_id: str,
comment_id: str,
reply_content: str,
) -> str:
"""Reply to a specific comment in a Google Document."""
return await _reply_to_comment_impl(service, app_name, document_id, comment_id, reply_content)
return await _reply_to_comment_impl(
service, app_name, document_id, comment_id, reply_content
)
@require_google_service("drive", "drive_file")
@handle_http_errors(resolve_func_name, service_type="drive")
async def resolve_comment(service, user_google_email: str, document_id: str, comment_id: str) -> str:
async def resolve_comment(
service, user_google_email: str, document_id: str, comment_id: str
) -> str:
"""Resolve a comment in a Google Document."""
return await _resolve_comment_impl(service, app_name, document_id, comment_id)
return await _resolve_comment_impl(
service, app_name, document_id, comment_id
)
elif file_id_param == "spreadsheet_id":
@require_google_service("drive", "drive_read")
@handle_http_errors(read_func_name, service_type="drive")
async def read_comments(service, user_google_email: str, spreadsheet_id: str) -> str:
async def read_comments(
service, user_google_email: str, spreadsheet_id: str
) -> str:
"""Read all comments from a Google Spreadsheet."""
return await _read_comments_impl(service, app_name, spreadsheet_id)
@require_google_service("drive", "drive_file")
@handle_http_errors(create_func_name, service_type="drive")
async def create_comment(service, user_google_email: str, spreadsheet_id: str, comment_content: str) -> str:
async def create_comment(
service, user_google_email: str, spreadsheet_id: str, comment_content: str
) -> str:
"""Create a new comment on a Google Spreadsheet."""
return await _create_comment_impl(service, app_name, spreadsheet_id, comment_content)
return await _create_comment_impl(
service, app_name, spreadsheet_id, comment_content
)
@require_google_service("drive", "drive_file")
@handle_http_errors(reply_func_name, service_type="drive")
async def reply_to_comment(service, user_google_email: str, spreadsheet_id: str, comment_id: str, reply_content: str) -> str:
async def reply_to_comment(
service,
user_google_email: str,
spreadsheet_id: str,
comment_id: str,
reply_content: str,
) -> str:
"""Reply to a specific comment in a Google Spreadsheet."""
return await _reply_to_comment_impl(service, app_name, spreadsheet_id, comment_id, reply_content)
return await _reply_to_comment_impl(
service, app_name, spreadsheet_id, comment_id, reply_content
)
@require_google_service("drive", "drive_file")
@handle_http_errors(resolve_func_name, service_type="drive")
async def resolve_comment(service, user_google_email: str, spreadsheet_id: str, comment_id: str) -> str:
async def resolve_comment(
service, user_google_email: str, spreadsheet_id: str, comment_id: str
) -> str:
"""Resolve a comment in a Google Spreadsheet."""
return await _resolve_comment_impl(service, app_name, spreadsheet_id, comment_id)
return await _resolve_comment_impl(
service, app_name, spreadsheet_id, comment_id
)
elif file_id_param == "presentation_id":
@require_google_service("drive", "drive_read")
@handle_http_errors(read_func_name, service_type="drive")
async def read_comments(service, user_google_email: str, presentation_id: str) -> str:
async def read_comments(
service, user_google_email: str, presentation_id: str
) -> str:
"""Read all comments from a Google Presentation."""
return await _read_comments_impl(service, app_name, presentation_id)
@require_google_service("drive", "drive_file")
@handle_http_errors(create_func_name, service_type="drive")
async def create_comment(service, user_google_email: str, presentation_id: str, comment_content: str) -> str:
async def create_comment(
service, user_google_email: str, presentation_id: str, comment_content: str
) -> str:
"""Create a new comment on a Google Presentation."""
return await _create_comment_impl(service, app_name, presentation_id, comment_content)
return await _create_comment_impl(
service, app_name, presentation_id, comment_content
)
@require_google_service("drive", "drive_file")
@handle_http_errors(reply_func_name, service_type="drive")
async def reply_to_comment(service, user_google_email: str, presentation_id: str, comment_id: str, reply_content: str) -> str:
async def reply_to_comment(
service,
user_google_email: str,
presentation_id: str,
comment_id: str,
reply_content: str,
) -> str:
"""Reply to a specific comment in a Google Presentation."""
return await _reply_to_comment_impl(service, app_name, presentation_id, comment_id, reply_content)
return await _reply_to_comment_impl(
service, app_name, presentation_id, comment_id, reply_content
)
@require_google_service("drive", "drive_file")
@handle_http_errors(resolve_func_name, service_type="drive")
async def resolve_comment(service, user_google_email: str, presentation_id: str, comment_id: str) -> str:
async def resolve_comment(
service, user_google_email: str, presentation_id: str, comment_id: str
) -> str:
"""Resolve a comment in a Google Presentation."""
return await _resolve_comment_impl(service, app_name, presentation_id, comment_id)
return await _resolve_comment_impl(
service, app_name, presentation_id, comment_id
)
# Set the proper function names and register with server
read_comments.__name__ = read_func_name
@@ -123,10 +180,10 @@ def create_comment_tools(app_name: str, file_id_param: str):
server.tool()(resolve_comment)
return {
'read_comments': read_comments,
'create_comment': create_comment,
'reply_to_comment': reply_to_comment,
'resolve_comment': resolve_comment
"read_comments": read_comments,
"create_comment": create_comment,
"reply_to_comment": reply_to_comment,
"resolve_comment": resolve_comment,
}
@@ -135,13 +192,15 @@ async def _read_comments_impl(service, app_name: str, file_id: str) -> str:
logger.info(f"[read_{app_name}_comments] Reading comments for {app_name} {file_id}")
response = await asyncio.to_thread(
service.comments().list(
service.comments()
.list(
fileId=file_id,
fields="comments(id,content,author,createdTime,modifiedTime,resolved,replies(content,author,id,createdTime,modifiedTime))"
).execute
fields="comments(id,content,author,createdTime,modifiedTime,resolved,replies(content,author,id,createdTime,modifiedTime))",
)
.execute
)
comments = response.get('comments', [])
comments = response.get("comments", [])
if not comments:
return f"No comments found in {app_name} {file_id}"
@@ -149,11 +208,11 @@ async def _read_comments_impl(service, app_name: str, file_id: str) -> str:
output = [f"Found {len(comments)} comments in {app_name} {file_id}:\\n"]
for comment in comments:
author = comment.get('author', {}).get('displayName', 'Unknown')
content = comment.get('content', '')
created = comment.get('createdTime', '')
resolved = comment.get('resolved', False)
comment_id = comment.get('id', '')
author = comment.get("author", {}).get("displayName", "Unknown")
content = comment.get("content", "")
created = comment.get("createdTime", "")
resolved = comment.get("resolved", False)
comment_id = comment.get("id", "")
status = " [RESOLVED]" if resolved else ""
output.append(f"Comment ID: {comment_id}")
@@ -162,14 +221,14 @@ async def _read_comments_impl(service, app_name: str, file_id: str) -> str:
output.append(f"Content: {content}")
# Add replies if any
replies = comment.get('replies', [])
replies = comment.get("replies", [])
if replies:
output.append(f" Replies ({len(replies)}):")
for reply in replies:
reply_author = reply.get('author', {}).get('displayName', 'Unknown')
reply_content = reply.get('content', '')
reply_created = reply.get('createdTime', '')
reply_id = reply.get('id', '')
reply_author = reply.get("author", {}).get("displayName", "Unknown")
reply_content = reply.get("content", "")
reply_created = reply.get("createdTime", "")
reply_id = reply.get("id", "")
output.append(f" Reply ID: {reply_id}")
output.append(f" Author: {reply_author}")
output.append(f" Created: {reply_created}")
@@ -180,69 +239,82 @@ async def _read_comments_impl(service, app_name: str, file_id: str) -> str:
return "\\n".join(output)
async def _create_comment_impl(service, app_name: str, file_id: str, comment_content: str) -> str:
async def _create_comment_impl(
service, app_name: str, file_id: str, comment_content: str
) -> str:
"""Implementation for creating a comment on any Google Workspace file."""
logger.info(f"[create_{app_name}_comment] Creating comment in {app_name} {file_id}")
body = {"content": comment_content}
comment = await asyncio.to_thread(
service.comments().create(
service.comments()
.create(
fileId=file_id,
body=body,
fields="id,content,author,createdTime,modifiedTime"
).execute
fields="id,content,author,createdTime,modifiedTime",
)
.execute
)
comment_id = comment.get('id', '')
author = comment.get('author', {}).get('displayName', 'Unknown')
created = comment.get('createdTime', '')
comment_id = comment.get("id", "")
author = comment.get("author", {}).get("displayName", "Unknown")
created = comment.get("createdTime", "")
return f"Comment created successfully!\\nComment ID: {comment_id}\\nAuthor: {author}\\nCreated: {created}\\nContent: {comment_content}"
async def _reply_to_comment_impl(service, app_name: str, file_id: str, comment_id: str, reply_content: str) -> str:
async def _reply_to_comment_impl(
service, app_name: str, file_id: str, comment_id: str, reply_content: str
) -> str:
"""Implementation for replying to a comment on any Google Workspace file."""
logger.info(f"[reply_to_{app_name}_comment] Replying to comment {comment_id} in {app_name} {file_id}")
logger.info(
f"[reply_to_{app_name}_comment] Replying to comment {comment_id} in {app_name} {file_id}"
)
body = {'content': reply_content}
body = {"content": reply_content}
reply = await asyncio.to_thread(
service.replies().create(
service.replies()
.create(
fileId=file_id,
commentId=comment_id,
body=body,
fields="id,content,author,createdTime,modifiedTime"
).execute
fields="id,content,author,createdTime,modifiedTime",
)
.execute
)
reply_id = reply.get('id', '')
author = reply.get('author', {}).get('displayName', 'Unknown')
created = reply.get('createdTime', '')
reply_id = reply.get("id", "")
author = reply.get("author", {}).get("displayName", "Unknown")
created = reply.get("createdTime", "")
return f"Reply posted successfully!\\nReply ID: {reply_id}\\nAuthor: {author}\\nCreated: {created}\\nContent: {reply_content}"
async def _resolve_comment_impl(service, app_name: str, file_id: str, comment_id: str) -> str:
async def _resolve_comment_impl(
service, app_name: str, file_id: str, comment_id: str
) -> str:
"""Implementation for resolving a comment on any Google Workspace file."""
logger.info(f"[resolve_{app_name}_comment] Resolving comment {comment_id} in {app_name} {file_id}")
logger.info(
f"[resolve_{app_name}_comment] Resolving comment {comment_id} in {app_name} {file_id}"
)
body = {
"content": "This comment has been resolved.",
"action": "resolve"
}
body = {"content": "This comment has been resolved.", "action": "resolve"}
reply = await asyncio.to_thread(
service.replies().create(
service.replies()
.create(
fileId=file_id,
commentId=comment_id,
body=body,
fields="id,content,author,createdTime,modifiedTime"
).execute
fields="id,content,author,createdTime,modifiedTime",
)
.execute
)
reply_id = reply.get('id', '')
author = reply.get('author', {}).get('displayName', 'Unknown')
created = reply.get('createdTime', '')
reply_id = reply.get("id", "")
author = reply.get("author", {}).get("displayName", "Unknown")
created = reply.get("createdTime", "")
return f"Comment {comment_id} has been resolved successfully.\\nResolve reply ID: {reply_id}\\nAuthor: {author}\\nCreated: {created}"
return f"Comment {comment_id} has been resolved successfully.\\nResolve reply ID: {reply_id}\\nAuthor: {author}\\nCreated: {created}"

View File

@@ -13,7 +13,7 @@ from auth.oauth_config import (
get_oauth_redirect_uri,
set_transport_mode,
get_transport_mode,
is_oauth21_enabled
is_oauth21_enabled,
)
# Server configuration
@@ -21,15 +21,17 @@ WORKSPACE_MCP_PORT = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)
WORKSPACE_MCP_BASE_URI = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
# Disable USER_GOOGLE_EMAIL in OAuth 2.1 multi-user mode
USER_GOOGLE_EMAIL = None if is_oauth21_enabled() else os.getenv("USER_GOOGLE_EMAIL", None)
USER_GOOGLE_EMAIL = (
None if is_oauth21_enabled() else os.getenv("USER_GOOGLE_EMAIL", None)
)
# Re-export OAuth functions for backward compatibility
__all__ = [
'WORKSPACE_MCP_PORT',
'WORKSPACE_MCP_BASE_URI',
'USER_GOOGLE_EMAIL',
'get_oauth_base_url',
'get_oauth_redirect_uri',
'set_transport_mode',
'get_transport_mode'
]
"WORKSPACE_MCP_PORT",
"WORKSPACE_MCP_BASE_URI",
"USER_GOOGLE_EMAIL",
"get_oauth_base_url",
"get_oauth_redirect_uri",
"set_transport_mode",
"get_transport_mode",
]

View File

@@ -8,9 +8,8 @@ _injected_oauth_credentials = contextvars.ContextVar(
)
# Context variable to hold FastMCP session ID for the life of a single request.
_fastmcp_session_id = contextvars.ContextVar(
"fastmcp_session_id", default=None
)
_fastmcp_session_id = contextvars.ContextVar("fastmcp_session_id", default=None)
def get_injected_oauth_credentials():
"""
@@ -19,6 +18,7 @@ def get_injected_oauth_credentials():
"""
return _injected_oauth_credentials.get()
def set_injected_oauth_credentials(credentials: Optional[dict]):
"""
Set or clear the injected OAuth credentials for the current request context.
@@ -26,6 +26,7 @@ def set_injected_oauth_credentials(credentials: Optional[dict]):
"""
_injected_oauth_credentials.set(credentials)
def get_fastmcp_session_id() -> Optional[str]:
"""
Retrieve the FastMCP session ID for the current request context.
@@ -33,9 +34,10 @@ def get_fastmcp_session_id() -> Optional[str]:
"""
return _fastmcp_session_id.get()
def set_fastmcp_session_id(session_id: Optional[str]):
"""
Set or clear the FastMCP session ID for the current request context.
This is called when a FastMCP request starts.
"""
_fastmcp_session_id.set(session_id)
_fastmcp_session_id.set(session_id)

View File

@@ -4,6 +4,7 @@ Enhanced Log Formatter for Google Workspace MCP
Provides visually appealing log formatting with emojis and consistent styling
to match the safe_print output format.
"""
import logging
import os
import re
@@ -15,12 +16,12 @@ class EnhancedLogFormatter(logging.Formatter):
# Color codes for terminals that support ANSI colors
COLORS = {
'DEBUG': '\033[36m', # Cyan
'INFO': '\033[32m', # Green
'WARNING': '\033[33m', # Yellow
'ERROR': '\033[31m', # Red
'CRITICAL': '\033[35m', # Magenta
'RESET': '\033[0m' # Reset
"DEBUG": "\033[36m", # Cyan
"INFO": "\033[32m", # Green
"WARNING": "\033[33m", # Yellow
"ERROR": "\033[31m", # Red
"CRITICAL": "\033[35m", # Magenta
"RESET": "\033[0m", # Reset
}
def __init__(self, use_colors: bool = True, *args, **kwargs):
@@ -43,8 +44,8 @@ class EnhancedLogFormatter(logging.Formatter):
# Build the formatted log entry
if self.use_colors:
color = self.COLORS.get(record.levelname, '')
reset = self.COLORS['RESET']
color = self.COLORS.get(record.levelname, "")
reset = self.COLORS["RESET"]
return f"{service_prefix} {color}{formatted_msg}{reset}"
else:
return f"{service_prefix} {formatted_msg}"
@@ -53,25 +54,25 @@ class EnhancedLogFormatter(logging.Formatter):
"""Get ASCII-safe prefix for Windows compatibility."""
# ASCII-safe prefixes for different services
ascii_prefixes = {
'core.tool_tier_loader': '[TOOLS]',
'core.tool_registry': '[REGISTRY]',
'auth.scopes': '[AUTH]',
'core.utils': '[UTILS]',
'auth.google_auth': '[OAUTH]',
'auth.credential_store': '[CREDS]',
'gcalendar.calendar_tools': '[CALENDAR]',
'gdrive.drive_tools': '[DRIVE]',
'gmail.gmail_tools': '[GMAIL]',
'gdocs.docs_tools': '[DOCS]',
'gsheets.sheets_tools': '[SHEETS]',
'gchat.chat_tools': '[CHAT]',
'gforms.forms_tools': '[FORMS]',
'gslides.slides_tools': '[SLIDES]',
'gtasks.tasks_tools': '[TASKS]',
'gsearch.search_tools': '[SEARCH]'
"core.tool_tier_loader": "[TOOLS]",
"core.tool_registry": "[REGISTRY]",
"auth.scopes": "[AUTH]",
"core.utils": "[UTILS]",
"auth.google_auth": "[OAUTH]",
"auth.credential_store": "[CREDS]",
"gcalendar.calendar_tools": "[CALENDAR]",
"gdrive.drive_tools": "[DRIVE]",
"gmail.gmail_tools": "[GMAIL]",
"gdocs.docs_tools": "[DOCS]",
"gsheets.sheets_tools": "[SHEETS]",
"gchat.chat_tools": "[CHAT]",
"gforms.forms_tools": "[FORMS]",
"gslides.slides_tools": "[SLIDES]",
"gtasks.tasks_tools": "[TASKS]",
"gsearch.search_tools": "[SEARCH]",
}
return ascii_prefixes.get(logger_name, f'[{level_name}]')
return ascii_prefixes.get(logger_name, f"[{level_name}]")
def _enhance_message(self, message: str) -> str:
"""Enhance the log message with better formatting."""
@@ -80,7 +81,9 @@ class EnhancedLogFormatter(logging.Formatter):
# Tool tier loading messages
if "resolved to" in message and "tools across" in message:
# Extract numbers and service names for better formatting
pattern = r"Tier '(\w+)' resolved to (\d+) tools across (\d+) services: (.+)"
pattern = (
r"Tier '(\w+)' resolved to (\d+) tools across (\d+) services: (.+)"
)
match = re.search(pattern, message)
if match:
tier, tool_count, service_count, services = match.groups()
@@ -113,7 +116,9 @@ class EnhancedLogFormatter(logging.Formatter):
return message
def setup_enhanced_logging(log_level: int = logging.INFO, use_colors: bool = True) -> None:
def setup_enhanced_logging(
log_level: int = logging.INFO, use_colors: bool = True
) -> None:
"""
Set up enhanced logging with ASCII prefix formatter for the entire application.
@@ -129,12 +134,19 @@ def setup_enhanced_logging(log_level: int = logging.INFO, use_colors: bool = Tru
# Update existing console handlers
for handler in root_logger.handlers:
if isinstance(handler, logging.StreamHandler) and handler.stream.name in ['<stderr>', '<stdout>']:
if isinstance(handler, logging.StreamHandler) and handler.stream.name in [
"<stderr>",
"<stdout>",
]:
handler.setFormatter(formatter)
# If no console handler exists, create one
console_handlers = [h for h in root_logger.handlers
if isinstance(h, logging.StreamHandler) and h.stream.name in ['<stderr>', '<stdout>']]
console_handlers = [
h
for h in root_logger.handlers
if isinstance(h, logging.StreamHandler)
and h.stream.name in ["<stderr>", "<stdout>"]
]
if not console_handlers:
console_handler = logging.StreamHandler()
@@ -157,7 +169,9 @@ def configure_file_logging(logger_name: str = None) -> bool:
bool: True if file logging was configured, False if skipped (stateless mode)
"""
# Check if stateless mode is enabled
stateless_mode = os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
stateless_mode = (
os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
)
if stateless_mode:
logger = logging.getLogger(logger_name)
@@ -170,14 +184,14 @@ def configure_file_logging(logger_name: str = None) -> bool:
log_file_dir = os.path.dirname(os.path.abspath(__file__))
# Go up one level since we're in core/ subdirectory
log_file_dir = os.path.dirname(log_file_dir)
log_file_path = os.path.join(log_file_dir, 'mcp_server_debug.log')
log_file_path = os.path.join(log_file_dir, "mcp_server_debug.log")
file_handler = logging.FileHandler(log_file_path, mode='a')
file_handler = logging.FileHandler(log_file_path, mode="a")
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(process)d - %(threadName)s '
'[%(module)s.%(funcName)s:%(lineno)d] - %(message)s'
"%(asctime)s - %(name)s - %(levelname)s - %(process)d - %(threadName)s "
"[%(module)s.%(funcName)s:%(lineno)d] - %(message)s"
)
file_handler.setFormatter(file_formatter)
target_logger.addHandler(file_handler)
@@ -187,5 +201,7 @@ def configure_file_logging(logger_name: str = None) -> bool:
return True
except Exception as e:
sys.stderr.write(f"CRITICAL: Failed to set up file logging to '{log_file_path}': {e}\n")
sys.stderr.write(
f"CRITICAL: Failed to set up file logging to '{log_file_path}': {e}\n"
)
return False

View File

@@ -13,9 +13,13 @@ from fastmcp.server.auth.providers.google import GoogleProvider
from auth.oauth21_session_store import get_oauth21_session_store, set_auth_provider
from auth.google_auth import handle_auth_callback, start_auth_flow, check_client_secrets
from auth.mcp_session_middleware import MCPSessionMiddleware
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
from auth.oauth_responses import (
create_error_response,
create_success_response,
create_server_error_response,
)
from auth.auth_info_middleware import AuthInfoMiddleware
from auth.scopes import SCOPES, get_current_scopes # noqa
from auth.scopes import SCOPES, get_current_scopes # noqa
from core.config import (
USER_GOOGLE_EMAIL,
get_transport_mode,
@@ -31,6 +35,7 @@ _legacy_callback_registered = False
session_middleware = Middleware(MCPSessionMiddleware)
# Custom FastMCP that adds secure middleware stack for OAuth 2.1
class SecureFastMCP(FastMCP):
def streamable_http_app(self) -> "Starlette":
@@ -46,6 +51,7 @@ class SecureFastMCP(FastMCP):
logger.info("Added middleware stack: Session Management")
return app
server = SecureFastMCP(
name="google_workspace",
auth=None,
@@ -69,6 +75,7 @@ def _ensure_legacy_callback_route() -> None:
server.custom_route("/oauth2callback", methods=["GET"])(legacy_oauth2_callback)
_legacy_callback_registered = True
def configure_server_for_http():
"""
Configures the authentication provider for HTTP transport.
@@ -83,6 +90,7 @@ def configure_server_for_http():
# Use centralized OAuth configuration
from auth.oauth_config import get_oauth_config
config = get_oauth_config()
# Check if OAuth 2.1 is enabled via centralized config
@@ -110,8 +118,12 @@ def configure_server_for_http():
)
# Disable protocol-level auth, expect bearer tokens in tool calls
server.auth = None
logger.info("OAuth 2.1 enabled with EXTERNAL provider mode - protocol-level auth disabled")
logger.info("Expecting Authorization bearer tokens in tool call headers")
logger.info(
"OAuth 2.1 enabled with EXTERNAL provider mode - protocol-level auth disabled"
)
logger.info(
"Expecting Authorization bearer tokens in tool call headers"
)
else:
# Standard OAuth 2.1 mode: use FastMCP's GoogleProvider
provider = GoogleProvider(
@@ -123,13 +135,17 @@ def configure_server_for_http():
)
# Enable protocol-level auth
server.auth = provider
logger.info("OAuth 2.1 enabled using FastMCP GoogleProvider with protocol-level auth")
logger.info(
"OAuth 2.1 enabled using FastMCP GoogleProvider with protocol-level auth"
)
# Always set auth provider for token validation in middleware
set_auth_provider(provider)
_auth_provider = provider
except Exception as exc:
logger.error("Failed to initialize FastMCP GoogleProvider: %s", exc, exc_info=True)
logger.error(
"Failed to initialize FastMCP GoogleProvider: %s", exc, exc_info=True
)
raise
else:
logger.info("OAuth 2.0 mode - Server will use legacy authentication.")
@@ -143,53 +159,56 @@ def get_auth_provider() -> Optional[GoogleProvider]:
"""Gets the global authentication provider instance."""
return _auth_provider
@server.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
try:
version = metadata.version("workspace-mcp")
except metadata.PackageNotFoundError:
version = "dev"
return JSONResponse({
"status": "healthy",
"service": "workspace-mcp",
"version": version,
"transport": get_transport_mode()
})
return JSONResponse(
{
"status": "healthy",
"service": "workspace-mcp",
"version": version,
"transport": get_transport_mode(),
}
)
@server.custom_route("/attachments/{file_id}", methods=["GET"])
async def serve_attachment(file_id: str, request: Request):
"""Serve a stored attachment file."""
from core.attachment_storage import get_attachment_storage
storage = get_attachment_storage()
metadata = storage.get_attachment_metadata(file_id)
if not metadata:
return JSONResponse(
{"error": "Attachment not found or expired"},
status_code=404
{"error": "Attachment not found or expired"}, status_code=404
)
file_path = storage.get_attachment_path(file_id)
if not file_path:
return JSONResponse(
{"error": "Attachment file not found"},
status_code=404
)
return JSONResponse({"error": "Attachment file not found"}, status_code=404)
return FileResponse(
path=str(file_path),
filename=metadata["filename"],
media_type=metadata["mime_type"]
media_type=metadata["mime_type"],
)
async def legacy_oauth2_callback(request: Request) -> HTMLResponse:
state = request.query_params.get("state")
code = request.query_params.get("code")
error = request.query_params.get("error")
if error:
msg = f"Authentication failed: Google returned an error: {error}. State: {state}."
msg = (
f"Authentication failed: Google returned an error: {error}. State: {state}."
)
logger.error(msg)
return create_error_response(msg)
@@ -206,17 +225,19 @@ async def legacy_oauth2_callback(request: Request) -> HTMLResponse:
logger.info(f"OAuth callback: Received code (state: {state}).")
mcp_session_id = None
if hasattr(request, 'state') and hasattr(request.state, 'session_id'):
if hasattr(request, "state") and hasattr(request.state, "session_id"):
mcp_session_id = request.state.session_id
verified_user_id, credentials = handle_auth_callback(
scopes=get_current_scopes(),
authorization_response=str(request.url),
redirect_uri=get_oauth_redirect_uri_for_current_mode(),
session_id=mcp_session_id
session_id=mcp_session_id,
)
logger.info(f"OAuth callback: Successfully authenticated user: {verified_user_id}.")
logger.info(
f"OAuth callback: Successfully authenticated user: {verified_user_id}."
)
try:
store = get_oauth21_session_store()
@@ -233,7 +254,9 @@ async def legacy_oauth2_callback(request: Request) -> HTMLResponse:
session_id=f"google-{state}",
mcp_session_id=mcp_session_id,
)
logger.info(f"Stored Google credentials in OAuth 2.1 session store for {verified_user_id}")
logger.info(
f"Stored Google credentials in OAuth 2.1 session store for {verified_user_id}"
)
except Exception as e:
logger.error(f"Failed to store credentials in OAuth 2.1 store: {e}")
@@ -242,8 +265,11 @@ async def legacy_oauth2_callback(request: Request) -> HTMLResponse:
logger.error(f"Error processing OAuth callback: {str(e)}", exc_info=True)
return create_server_error_response(str(e))
@server.tool()
async def start_google_auth(service_name: str, user_google_email: str = USER_GOOGLE_EMAIL) -> str:
async def start_google_auth(
service_name: str, user_google_email: str = USER_GOOGLE_EMAIL
) -> str:
"""
Manually initiate Google OAuth authentication flow.
@@ -268,7 +294,7 @@ async def start_google_auth(service_name: str, user_google_email: str = USER_GOO
auth_message = await start_auth_flow(
user_google_email=user_google_email,
service_name=service_name,
redirect_uri=get_oauth_redirect_uri_for_current_mode()
redirect_uri=get_oauth_redirect_uri_for_current_mode(),
)
return auth_message
except Exception as e:

View File

@@ -13,32 +13,37 @@ logger = logging.getLogger(__name__)
# Global registry of enabled tools
_enabled_tools: Optional[Set[str]] = None
def set_enabled_tools(tool_names: Optional[Set[str]]):
"""Set the globally enabled tools."""
global _enabled_tools
_enabled_tools = tool_names
def get_enabled_tools() -> Optional[Set[str]]:
"""Get the set of enabled tools, or None if all tools are enabled."""
return _enabled_tools
def is_tool_enabled(tool_name: str) -> bool:
"""Check if a specific tool is enabled."""
if _enabled_tools is None:
return True # All tools enabled by default
return tool_name in _enabled_tools
def conditional_tool(server, tool_name: str):
"""
Decorator that conditionally registers a tool based on the enabled tools set.
Args:
server: The FastMCP server instance
tool_name: The name of the tool to register
Returns:
Either the registered tool decorator or a no-op decorator
"""
def decorator(func: Callable) -> Callable:
if is_tool_enabled(tool_name):
logger.debug(f"Registering tool: {tool_name}")
@@ -46,51 +51,55 @@ def conditional_tool(server, tool_name: str):
else:
logger.debug(f"Skipping tool registration: {tool_name}")
return func
return decorator
def wrap_server_tool_method(server):
"""
Track tool registrations and filter them post-registration.
"""
original_tool = server.tool
server._tracked_tools = []
def tracking_tool(*args, **kwargs):
original_decorator = original_tool(*args, **kwargs)
def wrapper_decorator(func: Callable) -> Callable:
tool_name = func.__name__
server._tracked_tools.append(tool_name)
# Always apply the original decorator to register the tool
return original_decorator(func)
return wrapper_decorator
server.tool = tracking_tool
def filter_server_tools(server):
"""Remove disabled tools from the server after registration."""
enabled_tools = get_enabled_tools()
if enabled_tools is None:
return
tools_removed = 0
# Access FastMCP's tool registry via _tool_manager._tools
if hasattr(server, '_tool_manager'):
if hasattr(server, "_tool_manager"):
tool_manager = server._tool_manager
if hasattr(tool_manager, '_tools'):
if hasattr(tool_manager, "_tools"):
tool_registry = tool_manager._tools
tools_to_remove = []
for tool_name in list(tool_registry.keys()):
if not is_tool_enabled(tool_name):
tools_to_remove.append(tool_name)
for tool_name in tools_to_remove:
del tool_registry[tool_name]
tools_removed += 1
if tools_removed > 0:
logger.info(f"Tool tier filtering: removed {tools_removed} tools, {len(enabled_tools)} enabled")
logger.info(
f"Tool tier filtering: removed {tools_removed} tools, {len(enabled_tools)} enabled"
)

View File

@@ -15,33 +15,36 @@ logger = logging.getLogger(__name__)
TierLevel = Literal["core", "extended", "complete"]
class ToolTierLoader:
"""Loads and manages tool tiers from configuration."""
def __init__(self, config_path: Optional[str] = None):
"""
Initialize the tool tier loader.
Args:
config_path: Path to the tool_tiers.yaml file. If None, uses default location.
"""
if config_path is None:
# Default to core/tool_tiers.yaml relative to this file
config_path = Path(__file__).parent / "tool_tiers.yaml"
self.config_path = Path(config_path)
self._tiers_config: Optional[Dict] = None
def _load_config(self) -> Dict:
"""Load the tool tiers configuration from YAML file."""
if self._tiers_config is not None:
return self._tiers_config
if not self.config_path.exists():
raise FileNotFoundError(f"Tool tiers configuration not found: {self.config_path}")
raise FileNotFoundError(
f"Tool tiers configuration not found: {self.config_path}"
)
try:
with open(self.config_path, 'r', encoding='utf-8') as f:
with open(self.config_path, "r", encoding="utf-8") as f:
self._tiers_config = yaml.safe_load(f)
logger.info(f"Loaded tool tiers configuration from {self.config_path}")
return self._tiers_config
@@ -49,65 +52,71 @@ class ToolTierLoader:
raise ValueError(f"Invalid YAML in tool tiers configuration: {e}")
except Exception as e:
raise RuntimeError(f"Failed to load tool tiers configuration: {e}")
def get_available_services(self) -> List[str]:
"""Get list of all available services defined in the configuration."""
config = self._load_config()
return list(config.keys())
def get_tools_for_tier(self, tier: TierLevel, services: Optional[List[str]] = None) -> List[str]:
def get_tools_for_tier(
self, tier: TierLevel, services: Optional[List[str]] = None
) -> List[str]:
"""
Get all tools for a specific tier level.
Args:
tier: The tier level (core, extended, complete)
services: Optional list of services to filter by. If None, includes all services.
Returns:
List of tool names for the specified tier level
"""
config = self._load_config()
tools = []
# If no services specified, use all available services
if services is None:
services = self.get_available_services()
for service in services:
if service not in config:
logger.warning(f"Service '{service}' not found in tool tiers configuration")
logger.warning(
f"Service '{service}' not found in tool tiers configuration"
)
continue
service_config = config[service]
if tier not in service_config:
logger.debug(f"Tier '{tier}' not defined for service '{service}'")
continue
tier_tools = service_config[tier]
if tier_tools: # Handle empty lists
tools.extend(tier_tools)
return tools
def get_tools_up_to_tier(self, tier: TierLevel, services: Optional[List[str]] = None) -> List[str]:
def get_tools_up_to_tier(
self, tier: TierLevel, services: Optional[List[str]] = None
) -> List[str]:
"""
Get all tools up to and including the specified tier level.
Args:
tier: The maximum tier level to include
services: Optional list of services to filter by. If None, includes all services.
Returns:
List of tool names up to the specified tier level
"""
tier_order = ["core", "extended", "complete"]
max_tier_index = tier_order.index(tier)
tools = []
for i in range(max_tier_index + 1):
current_tier = tier_order[i]
tools.extend(self.get_tools_for_tier(current_tier, services))
# Remove duplicates while preserving order
seen = set()
unique_tools = []
@@ -115,39 +124,41 @@ class ToolTierLoader:
if tool not in seen:
seen.add(tool)
unique_tools.append(tool)
return unique_tools
def get_services_for_tools(self, tool_names: List[str]) -> Set[str]:
"""
Get the service names that provide the specified tools.
Args:
tool_names: List of tool names to lookup
Returns:
Set of service names that provide any of the specified tools
"""
config = self._load_config()
services = set()
for service, service_config in config.items():
for tier_name, tier_tools in service_config.items():
if tier_tools and any(tool in tier_tools for tool in tool_names):
services.add(service)
break
return services
def get_tools_for_tier(tier: TierLevel, services: Optional[List[str]] = None) -> List[str]:
def get_tools_for_tier(
tier: TierLevel, services: Optional[List[str]] = None
) -> List[str]:
"""
Convenience function to get tools for a specific tier.
Args:
tier: The tier level (core, extended, complete)
services: Optional list of services to filter by
Returns:
List of tool names for the specified tier level
"""
@@ -155,27 +166,31 @@ def get_tools_for_tier(tier: TierLevel, services: Optional[List[str]] = None) ->
return loader.get_tools_up_to_tier(tier, services)
def resolve_tools_from_tier(tier: TierLevel, services: Optional[List[str]] = None) -> tuple[List[str], List[str]]:
def resolve_tools_from_tier(
tier: TierLevel, services: Optional[List[str]] = None
) -> tuple[List[str], List[str]]:
"""
Resolve tool names and service names for the specified tier.
Args:
tier: The tier level (core, extended, complete)
services: Optional list of services to filter by
Returns:
Tuple of (tool_names, service_names) where:
- tool_names: List of specific tool names for the tier
- service_names: List of service names that should be imported
"""
loader = ToolTierLoader()
# Get all tools for the tier
tools = loader.get_tools_up_to_tier(tier, services)
# Map back to service names
service_names = loader.get_services_for_tools(tools)
logger.info(f"Tier '{tier}' resolved to {len(tools)} tools across {len(service_names)} services: {sorted(service_names)}")
return tools, sorted(service_names)
logger.info(
f"Tier '{tier}' resolved to {len(tools)} tools across {len(service_names)} services: {sorted(service_names)}"
)
return tools, sorted(service_names)

View File

@@ -27,6 +27,7 @@ class UserInputError(Exception):
pass
def check_credentials_directory_permissions(credentials_dir: str = None) -> None:
"""
Check if the service has appropriate permissions to create and write to the .credentials directory.
@@ -181,7 +182,7 @@ def extract_office_xml_text(file_bytes: bytes, mime_type: str) -> Optional[str]:
member_texts.append(shared_strings[ss_idx])
else:
logger.warning(
f"Invalid shared string index {ss_idx} in {member}. Max index: {len(shared_strings)-1}"
f"Invalid shared string index {ss_idx} in {member}. Max index: {len(shared_strings) - 1}"
)
except ValueError:
logger.warning(
@@ -240,7 +241,9 @@ def extract_office_xml_text(file_bytes: bytes, mime_type: str) -> Optional[str]:
return None
def handle_http_errors(tool_name: str, is_read_only: bool = False, service_type: Optional[str] = None):
def handle_http_errors(
tool_name: str, is_read_only: bool = False, service_type: Optional[str] = None
):
"""
A decorator to handle Google API HttpErrors and transient SSL errors in a standardized way.
@@ -288,11 +291,16 @@ def handle_http_errors(tool_name: str, is_read_only: bool = False, service_type:
except HttpError as error:
user_google_email = kwargs.get("user_google_email", "N/A")
error_details = str(error)
# Check if this is an API not enabled error
if error.resp.status == 403 and "accessNotConfigured" in error_details:
enablement_msg = get_api_enablement_message(error_details, service_type)
if (
error.resp.status == 403
and "accessNotConfigured" in error_details
):
enablement_msg = get_api_enablement_message(
error_details, service_type
)
if enablement_msg:
message = (
f"API error in {tool_name}: {enablement_msg}\n\n"
@@ -314,7 +322,7 @@ def handle_http_errors(tool_name: str, is_read_only: bool = False, service_type:
else:
# Other HTTP errors (400 Bad Request, etc.) - don't suggest re-auth
message = f"API error in {tool_name}: {error}"
logger.error(f"API error in {tool_name}: {error}", exc_info=True)
raise Exception(message) from error
except TransientNetworkError: