apply ruff formatting
This commit is contained in:
@@ -45,15 +45,17 @@ INTERNAL_SERVICE_TO_API: Dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
def extract_api_info_from_error(error_details: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
def extract_api_info_from_error(
|
||||
error_details: str,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Extract API service and project ID from error details.
|
||||
|
||||
Returns:
|
||||
Tuple of (api_service, project_id) or (None, None) if not found
|
||||
"""
|
||||
api_pattern = r'https://console\.developers\.google\.com/apis/api/([^/]+)/overview'
|
||||
project_pattern = r'project[=\s]+([a-zA-Z0-9-]+)'
|
||||
api_pattern = r"https://console\.developers\.google\.com/apis/api/([^/]+)/overview"
|
||||
project_pattern = r"project[=\s]+([a-zA-Z0-9-]+)"
|
||||
|
||||
api_match = re.search(api_pattern, error_details)
|
||||
project_match = re.search(project_pattern, error_details)
|
||||
@@ -64,7 +66,9 @@ def extract_api_info_from_error(error_details: str) -> Tuple[Optional[str], Opti
|
||||
return api_service, project_id
|
||||
|
||||
|
||||
def get_api_enablement_message(error_details: str, service_type: Optional[str] = None) -> str:
|
||||
def get_api_enablement_message(
|
||||
error_details: str, service_type: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate a helpful error message with direct API enablement link.
|
||||
|
||||
@@ -88,7 +92,7 @@ def get_api_enablement_message(error_details: str, service_type: Optional[str] =
|
||||
enable_link = API_ENABLEMENT_LINKS[api_service]
|
||||
service_display_name = next(
|
||||
(name for name, api in SERVICE_NAME_TO_API.items() if api == api_service),
|
||||
api_service
|
||||
api_service,
|
||||
)
|
||||
|
||||
message = (
|
||||
@@ -101,4 +105,4 @@ def get_api_enablement_message(error_details: str, service_type: Optional[str] =
|
||||
|
||||
return message
|
||||
|
||||
return ""
|
||||
return ""
|
||||
|
||||
@@ -77,7 +77,9 @@ class AttachmentStorage:
|
||||
file_path = STORAGE_DIR / f"{file_id}{extension}"
|
||||
try:
|
||||
file_path.write_bytes(file_bytes)
|
||||
logger.info(f"Saved attachment {file_id} ({len(file_bytes)} bytes) to {file_path}")
|
||||
logger.info(
|
||||
f"Saved attachment {file_id} ({len(file_bytes)} bytes) to {file_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save attachment to {file_path}: {e}")
|
||||
raise
|
||||
@@ -213,4 +215,3 @@ def get_attachment_url(file_id: str) -> str:
|
||||
base_url = f"{WORKSPACE_MCP_BASE_URI}:{WORKSPACE_MCP_PORT}"
|
||||
|
||||
return f"{base_url}/attachments/{file_id}"
|
||||
|
||||
|
||||
208
core/comments.py
208
core/comments.py
@@ -36,79 +36,136 @@ def create_comment_tools(app_name: str, file_id_param: str):
|
||||
|
||||
# Create functions without decorators first, then apply decorators with proper names
|
||||
if file_id_param == "document_id":
|
||||
|
||||
@require_google_service("drive", "drive_read")
|
||||
@handle_http_errors(read_func_name, service_type="drive")
|
||||
async def read_comments(service, user_google_email: str, document_id: str) -> str:
|
||||
async def read_comments(
|
||||
service, user_google_email: str, document_id: str
|
||||
) -> str:
|
||||
"""Read all comments from a Google Document."""
|
||||
return await _read_comments_impl(service, app_name, document_id)
|
||||
|
||||
@require_google_service("drive", "drive_file")
|
||||
@handle_http_errors(create_func_name, service_type="drive")
|
||||
async def create_comment(service, user_google_email: str, document_id: str, comment_content: str) -> str:
|
||||
async def create_comment(
|
||||
service, user_google_email: str, document_id: str, comment_content: str
|
||||
) -> str:
|
||||
"""Create a new comment on a Google Document."""
|
||||
return await _create_comment_impl(service, app_name, document_id, comment_content)
|
||||
return await _create_comment_impl(
|
||||
service, app_name, document_id, comment_content
|
||||
)
|
||||
|
||||
@require_google_service("drive", "drive_file")
|
||||
@handle_http_errors(reply_func_name, service_type="drive")
|
||||
async def reply_to_comment(service, user_google_email: str, document_id: str, comment_id: str, reply_content: str) -> str:
|
||||
async def reply_to_comment(
|
||||
service,
|
||||
user_google_email: str,
|
||||
document_id: str,
|
||||
comment_id: str,
|
||||
reply_content: str,
|
||||
) -> str:
|
||||
"""Reply to a specific comment in a Google Document."""
|
||||
return await _reply_to_comment_impl(service, app_name, document_id, comment_id, reply_content)
|
||||
return await _reply_to_comment_impl(
|
||||
service, app_name, document_id, comment_id, reply_content
|
||||
)
|
||||
|
||||
@require_google_service("drive", "drive_file")
|
||||
@handle_http_errors(resolve_func_name, service_type="drive")
|
||||
async def resolve_comment(service, user_google_email: str, document_id: str, comment_id: str) -> str:
|
||||
async def resolve_comment(
|
||||
service, user_google_email: str, document_id: str, comment_id: str
|
||||
) -> str:
|
||||
"""Resolve a comment in a Google Document."""
|
||||
return await _resolve_comment_impl(service, app_name, document_id, comment_id)
|
||||
return await _resolve_comment_impl(
|
||||
service, app_name, document_id, comment_id
|
||||
)
|
||||
|
||||
elif file_id_param == "spreadsheet_id":
|
||||
|
||||
@require_google_service("drive", "drive_read")
|
||||
@handle_http_errors(read_func_name, service_type="drive")
|
||||
async def read_comments(service, user_google_email: str, spreadsheet_id: str) -> str:
|
||||
async def read_comments(
|
||||
service, user_google_email: str, spreadsheet_id: str
|
||||
) -> str:
|
||||
"""Read all comments from a Google Spreadsheet."""
|
||||
return await _read_comments_impl(service, app_name, spreadsheet_id)
|
||||
|
||||
@require_google_service("drive", "drive_file")
|
||||
@handle_http_errors(create_func_name, service_type="drive")
|
||||
async def create_comment(service, user_google_email: str, spreadsheet_id: str, comment_content: str) -> str:
|
||||
async def create_comment(
|
||||
service, user_google_email: str, spreadsheet_id: str, comment_content: str
|
||||
) -> str:
|
||||
"""Create a new comment on a Google Spreadsheet."""
|
||||
return await _create_comment_impl(service, app_name, spreadsheet_id, comment_content)
|
||||
return await _create_comment_impl(
|
||||
service, app_name, spreadsheet_id, comment_content
|
||||
)
|
||||
|
||||
@require_google_service("drive", "drive_file")
|
||||
@handle_http_errors(reply_func_name, service_type="drive")
|
||||
async def reply_to_comment(service, user_google_email: str, spreadsheet_id: str, comment_id: str, reply_content: str) -> str:
|
||||
async def reply_to_comment(
|
||||
service,
|
||||
user_google_email: str,
|
||||
spreadsheet_id: str,
|
||||
comment_id: str,
|
||||
reply_content: str,
|
||||
) -> str:
|
||||
"""Reply to a specific comment in a Google Spreadsheet."""
|
||||
return await _reply_to_comment_impl(service, app_name, spreadsheet_id, comment_id, reply_content)
|
||||
return await _reply_to_comment_impl(
|
||||
service, app_name, spreadsheet_id, comment_id, reply_content
|
||||
)
|
||||
|
||||
@require_google_service("drive", "drive_file")
|
||||
@handle_http_errors(resolve_func_name, service_type="drive")
|
||||
async def resolve_comment(service, user_google_email: str, spreadsheet_id: str, comment_id: str) -> str:
|
||||
async def resolve_comment(
|
||||
service, user_google_email: str, spreadsheet_id: str, comment_id: str
|
||||
) -> str:
|
||||
"""Resolve a comment in a Google Spreadsheet."""
|
||||
return await _resolve_comment_impl(service, app_name, spreadsheet_id, comment_id)
|
||||
return await _resolve_comment_impl(
|
||||
service, app_name, spreadsheet_id, comment_id
|
||||
)
|
||||
|
||||
elif file_id_param == "presentation_id":
|
||||
|
||||
@require_google_service("drive", "drive_read")
|
||||
@handle_http_errors(read_func_name, service_type="drive")
|
||||
async def read_comments(service, user_google_email: str, presentation_id: str) -> str:
|
||||
async def read_comments(
|
||||
service, user_google_email: str, presentation_id: str
|
||||
) -> str:
|
||||
"""Read all comments from a Google Presentation."""
|
||||
return await _read_comments_impl(service, app_name, presentation_id)
|
||||
|
||||
@require_google_service("drive", "drive_file")
|
||||
@handle_http_errors(create_func_name, service_type="drive")
|
||||
async def create_comment(service, user_google_email: str, presentation_id: str, comment_content: str) -> str:
|
||||
async def create_comment(
|
||||
service, user_google_email: str, presentation_id: str, comment_content: str
|
||||
) -> str:
|
||||
"""Create a new comment on a Google Presentation."""
|
||||
return await _create_comment_impl(service, app_name, presentation_id, comment_content)
|
||||
return await _create_comment_impl(
|
||||
service, app_name, presentation_id, comment_content
|
||||
)
|
||||
|
||||
@require_google_service("drive", "drive_file")
|
||||
@handle_http_errors(reply_func_name, service_type="drive")
|
||||
async def reply_to_comment(service, user_google_email: str, presentation_id: str, comment_id: str, reply_content: str) -> str:
|
||||
async def reply_to_comment(
|
||||
service,
|
||||
user_google_email: str,
|
||||
presentation_id: str,
|
||||
comment_id: str,
|
||||
reply_content: str,
|
||||
) -> str:
|
||||
"""Reply to a specific comment in a Google Presentation."""
|
||||
return await _reply_to_comment_impl(service, app_name, presentation_id, comment_id, reply_content)
|
||||
return await _reply_to_comment_impl(
|
||||
service, app_name, presentation_id, comment_id, reply_content
|
||||
)
|
||||
|
||||
@require_google_service("drive", "drive_file")
|
||||
@handle_http_errors(resolve_func_name, service_type="drive")
|
||||
async def resolve_comment(service, user_google_email: str, presentation_id: str, comment_id: str) -> str:
|
||||
async def resolve_comment(
|
||||
service, user_google_email: str, presentation_id: str, comment_id: str
|
||||
) -> str:
|
||||
"""Resolve a comment in a Google Presentation."""
|
||||
return await _resolve_comment_impl(service, app_name, presentation_id, comment_id)
|
||||
return await _resolve_comment_impl(
|
||||
service, app_name, presentation_id, comment_id
|
||||
)
|
||||
|
||||
# Set the proper function names and register with server
|
||||
read_comments.__name__ = read_func_name
|
||||
@@ -123,10 +180,10 @@ def create_comment_tools(app_name: str, file_id_param: str):
|
||||
server.tool()(resolve_comment)
|
||||
|
||||
return {
|
||||
'read_comments': read_comments,
|
||||
'create_comment': create_comment,
|
||||
'reply_to_comment': reply_to_comment,
|
||||
'resolve_comment': resolve_comment
|
||||
"read_comments": read_comments,
|
||||
"create_comment": create_comment,
|
||||
"reply_to_comment": reply_to_comment,
|
||||
"resolve_comment": resolve_comment,
|
||||
}
|
||||
|
||||
|
||||
@@ -135,13 +192,15 @@ async def _read_comments_impl(service, app_name: str, file_id: str) -> str:
|
||||
logger.info(f"[read_{app_name}_comments] Reading comments for {app_name} {file_id}")
|
||||
|
||||
response = await asyncio.to_thread(
|
||||
service.comments().list(
|
||||
service.comments()
|
||||
.list(
|
||||
fileId=file_id,
|
||||
fields="comments(id,content,author,createdTime,modifiedTime,resolved,replies(content,author,id,createdTime,modifiedTime))"
|
||||
).execute
|
||||
fields="comments(id,content,author,createdTime,modifiedTime,resolved,replies(content,author,id,createdTime,modifiedTime))",
|
||||
)
|
||||
.execute
|
||||
)
|
||||
|
||||
comments = response.get('comments', [])
|
||||
comments = response.get("comments", [])
|
||||
|
||||
if not comments:
|
||||
return f"No comments found in {app_name} {file_id}"
|
||||
@@ -149,11 +208,11 @@ async def _read_comments_impl(service, app_name: str, file_id: str) -> str:
|
||||
output = [f"Found {len(comments)} comments in {app_name} {file_id}:\\n"]
|
||||
|
||||
for comment in comments:
|
||||
author = comment.get('author', {}).get('displayName', 'Unknown')
|
||||
content = comment.get('content', '')
|
||||
created = comment.get('createdTime', '')
|
||||
resolved = comment.get('resolved', False)
|
||||
comment_id = comment.get('id', '')
|
||||
author = comment.get("author", {}).get("displayName", "Unknown")
|
||||
content = comment.get("content", "")
|
||||
created = comment.get("createdTime", "")
|
||||
resolved = comment.get("resolved", False)
|
||||
comment_id = comment.get("id", "")
|
||||
status = " [RESOLVED]" if resolved else ""
|
||||
|
||||
output.append(f"Comment ID: {comment_id}")
|
||||
@@ -162,14 +221,14 @@ async def _read_comments_impl(service, app_name: str, file_id: str) -> str:
|
||||
output.append(f"Content: {content}")
|
||||
|
||||
# Add replies if any
|
||||
replies = comment.get('replies', [])
|
||||
replies = comment.get("replies", [])
|
||||
if replies:
|
||||
output.append(f" Replies ({len(replies)}):")
|
||||
for reply in replies:
|
||||
reply_author = reply.get('author', {}).get('displayName', 'Unknown')
|
||||
reply_content = reply.get('content', '')
|
||||
reply_created = reply.get('createdTime', '')
|
||||
reply_id = reply.get('id', '')
|
||||
reply_author = reply.get("author", {}).get("displayName", "Unknown")
|
||||
reply_content = reply.get("content", "")
|
||||
reply_created = reply.get("createdTime", "")
|
||||
reply_id = reply.get("id", "")
|
||||
output.append(f" Reply ID: {reply_id}")
|
||||
output.append(f" Author: {reply_author}")
|
||||
output.append(f" Created: {reply_created}")
|
||||
@@ -180,69 +239,82 @@ async def _read_comments_impl(service, app_name: str, file_id: str) -> str:
|
||||
return "\\n".join(output)
|
||||
|
||||
|
||||
async def _create_comment_impl(service, app_name: str, file_id: str, comment_content: str) -> str:
|
||||
async def _create_comment_impl(
|
||||
service, app_name: str, file_id: str, comment_content: str
|
||||
) -> str:
|
||||
"""Implementation for creating a comment on any Google Workspace file."""
|
||||
logger.info(f"[create_{app_name}_comment] Creating comment in {app_name} {file_id}")
|
||||
|
||||
body = {"content": comment_content}
|
||||
|
||||
comment = await asyncio.to_thread(
|
||||
service.comments().create(
|
||||
service.comments()
|
||||
.create(
|
||||
fileId=file_id,
|
||||
body=body,
|
||||
fields="id,content,author,createdTime,modifiedTime"
|
||||
).execute
|
||||
fields="id,content,author,createdTime,modifiedTime",
|
||||
)
|
||||
.execute
|
||||
)
|
||||
|
||||
comment_id = comment.get('id', '')
|
||||
author = comment.get('author', {}).get('displayName', 'Unknown')
|
||||
created = comment.get('createdTime', '')
|
||||
comment_id = comment.get("id", "")
|
||||
author = comment.get("author", {}).get("displayName", "Unknown")
|
||||
created = comment.get("createdTime", "")
|
||||
|
||||
return f"Comment created successfully!\\nComment ID: {comment_id}\\nAuthor: {author}\\nCreated: {created}\\nContent: {comment_content}"
|
||||
|
||||
|
||||
async def _reply_to_comment_impl(service, app_name: str, file_id: str, comment_id: str, reply_content: str) -> str:
|
||||
async def _reply_to_comment_impl(
|
||||
service, app_name: str, file_id: str, comment_id: str, reply_content: str
|
||||
) -> str:
|
||||
"""Implementation for replying to a comment on any Google Workspace file."""
|
||||
logger.info(f"[reply_to_{app_name}_comment] Replying to comment {comment_id} in {app_name} {file_id}")
|
||||
logger.info(
|
||||
f"[reply_to_{app_name}_comment] Replying to comment {comment_id} in {app_name} {file_id}"
|
||||
)
|
||||
|
||||
body = {'content': reply_content}
|
||||
body = {"content": reply_content}
|
||||
|
||||
reply = await asyncio.to_thread(
|
||||
service.replies().create(
|
||||
service.replies()
|
||||
.create(
|
||||
fileId=file_id,
|
||||
commentId=comment_id,
|
||||
body=body,
|
||||
fields="id,content,author,createdTime,modifiedTime"
|
||||
).execute
|
||||
fields="id,content,author,createdTime,modifiedTime",
|
||||
)
|
||||
.execute
|
||||
)
|
||||
|
||||
reply_id = reply.get('id', '')
|
||||
author = reply.get('author', {}).get('displayName', 'Unknown')
|
||||
created = reply.get('createdTime', '')
|
||||
reply_id = reply.get("id", "")
|
||||
author = reply.get("author", {}).get("displayName", "Unknown")
|
||||
created = reply.get("createdTime", "")
|
||||
|
||||
return f"Reply posted successfully!\\nReply ID: {reply_id}\\nAuthor: {author}\\nCreated: {created}\\nContent: {reply_content}"
|
||||
|
||||
|
||||
async def _resolve_comment_impl(service, app_name: str, file_id: str, comment_id: str) -> str:
|
||||
async def _resolve_comment_impl(
|
||||
service, app_name: str, file_id: str, comment_id: str
|
||||
) -> str:
|
||||
"""Implementation for resolving a comment on any Google Workspace file."""
|
||||
logger.info(f"[resolve_{app_name}_comment] Resolving comment {comment_id} in {app_name} {file_id}")
|
||||
logger.info(
|
||||
f"[resolve_{app_name}_comment] Resolving comment {comment_id} in {app_name} {file_id}"
|
||||
)
|
||||
|
||||
body = {
|
||||
"content": "This comment has been resolved.",
|
||||
"action": "resolve"
|
||||
}
|
||||
body = {"content": "This comment has been resolved.", "action": "resolve"}
|
||||
|
||||
reply = await asyncio.to_thread(
|
||||
service.replies().create(
|
||||
service.replies()
|
||||
.create(
|
||||
fileId=file_id,
|
||||
commentId=comment_id,
|
||||
body=body,
|
||||
fields="id,content,author,createdTime,modifiedTime"
|
||||
).execute
|
||||
fields="id,content,author,createdTime,modifiedTime",
|
||||
)
|
||||
.execute
|
||||
)
|
||||
|
||||
reply_id = reply.get('id', '')
|
||||
author = reply.get('author', {}).get('displayName', 'Unknown')
|
||||
created = reply.get('createdTime', '')
|
||||
reply_id = reply.get("id", "")
|
||||
author = reply.get("author", {}).get("displayName", "Unknown")
|
||||
created = reply.get("createdTime", "")
|
||||
|
||||
return f"Comment {comment_id} has been resolved successfully.\\nResolve reply ID: {reply_id}\\nAuthor: {author}\\nCreated: {created}"
|
||||
return f"Comment {comment_id} has been resolved successfully.\\nResolve reply ID: {reply_id}\\nAuthor: {author}\\nCreated: {created}"
|
||||
|
||||
@@ -13,7 +13,7 @@ from auth.oauth_config import (
|
||||
get_oauth_redirect_uri,
|
||||
set_transport_mode,
|
||||
get_transport_mode,
|
||||
is_oauth21_enabled
|
||||
is_oauth21_enabled,
|
||||
)
|
||||
|
||||
# Server configuration
|
||||
@@ -21,15 +21,17 @@ WORKSPACE_MCP_PORT = int(os.getenv("PORT", os.getenv("WORKSPACE_MCP_PORT", 8000)
|
||||
WORKSPACE_MCP_BASE_URI = os.getenv("WORKSPACE_MCP_BASE_URI", "http://localhost")
|
||||
|
||||
# Disable USER_GOOGLE_EMAIL in OAuth 2.1 multi-user mode
|
||||
USER_GOOGLE_EMAIL = None if is_oauth21_enabled() else os.getenv("USER_GOOGLE_EMAIL", None)
|
||||
USER_GOOGLE_EMAIL = (
|
||||
None if is_oauth21_enabled() else os.getenv("USER_GOOGLE_EMAIL", None)
|
||||
)
|
||||
|
||||
# Re-export OAuth functions for backward compatibility
|
||||
__all__ = [
|
||||
'WORKSPACE_MCP_PORT',
|
||||
'WORKSPACE_MCP_BASE_URI',
|
||||
'USER_GOOGLE_EMAIL',
|
||||
'get_oauth_base_url',
|
||||
'get_oauth_redirect_uri',
|
||||
'set_transport_mode',
|
||||
'get_transport_mode'
|
||||
]
|
||||
"WORKSPACE_MCP_PORT",
|
||||
"WORKSPACE_MCP_BASE_URI",
|
||||
"USER_GOOGLE_EMAIL",
|
||||
"get_oauth_base_url",
|
||||
"get_oauth_redirect_uri",
|
||||
"set_transport_mode",
|
||||
"get_transport_mode",
|
||||
]
|
||||
|
||||
@@ -8,9 +8,8 @@ _injected_oauth_credentials = contextvars.ContextVar(
|
||||
)
|
||||
|
||||
# Context variable to hold FastMCP session ID for the life of a single request.
|
||||
_fastmcp_session_id = contextvars.ContextVar(
|
||||
"fastmcp_session_id", default=None
|
||||
)
|
||||
_fastmcp_session_id = contextvars.ContextVar("fastmcp_session_id", default=None)
|
||||
|
||||
|
||||
def get_injected_oauth_credentials():
|
||||
"""
|
||||
@@ -19,6 +18,7 @@ def get_injected_oauth_credentials():
|
||||
"""
|
||||
return _injected_oauth_credentials.get()
|
||||
|
||||
|
||||
def set_injected_oauth_credentials(credentials: Optional[dict]):
|
||||
"""
|
||||
Set or clear the injected OAuth credentials for the current request context.
|
||||
@@ -26,6 +26,7 @@ def set_injected_oauth_credentials(credentials: Optional[dict]):
|
||||
"""
|
||||
_injected_oauth_credentials.set(credentials)
|
||||
|
||||
|
||||
def get_fastmcp_session_id() -> Optional[str]:
|
||||
"""
|
||||
Retrieve the FastMCP session ID for the current request context.
|
||||
@@ -33,9 +34,10 @@ def get_fastmcp_session_id() -> Optional[str]:
|
||||
"""
|
||||
return _fastmcp_session_id.get()
|
||||
|
||||
|
||||
def set_fastmcp_session_id(session_id: Optional[str]):
|
||||
"""
|
||||
Set or clear the FastMCP session ID for the current request context.
|
||||
This is called when a FastMCP request starts.
|
||||
"""
|
||||
_fastmcp_session_id.set(session_id)
|
||||
_fastmcp_session_id.set(session_id)
|
||||
|
||||
@@ -4,6 +4,7 @@ Enhanced Log Formatter for Google Workspace MCP
|
||||
Provides visually appealing log formatting with emojis and consistent styling
|
||||
to match the safe_print output format.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -15,12 +16,12 @@ class EnhancedLogFormatter(logging.Formatter):
|
||||
|
||||
# Color codes for terminals that support ANSI colors
|
||||
COLORS = {
|
||||
'DEBUG': '\033[36m', # Cyan
|
||||
'INFO': '\033[32m', # Green
|
||||
'WARNING': '\033[33m', # Yellow
|
||||
'ERROR': '\033[31m', # Red
|
||||
'CRITICAL': '\033[35m', # Magenta
|
||||
'RESET': '\033[0m' # Reset
|
||||
"DEBUG": "\033[36m", # Cyan
|
||||
"INFO": "\033[32m", # Green
|
||||
"WARNING": "\033[33m", # Yellow
|
||||
"ERROR": "\033[31m", # Red
|
||||
"CRITICAL": "\033[35m", # Magenta
|
||||
"RESET": "\033[0m", # Reset
|
||||
}
|
||||
|
||||
def __init__(self, use_colors: bool = True, *args, **kwargs):
|
||||
@@ -43,8 +44,8 @@ class EnhancedLogFormatter(logging.Formatter):
|
||||
|
||||
# Build the formatted log entry
|
||||
if self.use_colors:
|
||||
color = self.COLORS.get(record.levelname, '')
|
||||
reset = self.COLORS['RESET']
|
||||
color = self.COLORS.get(record.levelname, "")
|
||||
reset = self.COLORS["RESET"]
|
||||
return f"{service_prefix} {color}{formatted_msg}{reset}"
|
||||
else:
|
||||
return f"{service_prefix} {formatted_msg}"
|
||||
@@ -53,25 +54,25 @@ class EnhancedLogFormatter(logging.Formatter):
|
||||
"""Get ASCII-safe prefix for Windows compatibility."""
|
||||
# ASCII-safe prefixes for different services
|
||||
ascii_prefixes = {
|
||||
'core.tool_tier_loader': '[TOOLS]',
|
||||
'core.tool_registry': '[REGISTRY]',
|
||||
'auth.scopes': '[AUTH]',
|
||||
'core.utils': '[UTILS]',
|
||||
'auth.google_auth': '[OAUTH]',
|
||||
'auth.credential_store': '[CREDS]',
|
||||
'gcalendar.calendar_tools': '[CALENDAR]',
|
||||
'gdrive.drive_tools': '[DRIVE]',
|
||||
'gmail.gmail_tools': '[GMAIL]',
|
||||
'gdocs.docs_tools': '[DOCS]',
|
||||
'gsheets.sheets_tools': '[SHEETS]',
|
||||
'gchat.chat_tools': '[CHAT]',
|
||||
'gforms.forms_tools': '[FORMS]',
|
||||
'gslides.slides_tools': '[SLIDES]',
|
||||
'gtasks.tasks_tools': '[TASKS]',
|
||||
'gsearch.search_tools': '[SEARCH]'
|
||||
"core.tool_tier_loader": "[TOOLS]",
|
||||
"core.tool_registry": "[REGISTRY]",
|
||||
"auth.scopes": "[AUTH]",
|
||||
"core.utils": "[UTILS]",
|
||||
"auth.google_auth": "[OAUTH]",
|
||||
"auth.credential_store": "[CREDS]",
|
||||
"gcalendar.calendar_tools": "[CALENDAR]",
|
||||
"gdrive.drive_tools": "[DRIVE]",
|
||||
"gmail.gmail_tools": "[GMAIL]",
|
||||
"gdocs.docs_tools": "[DOCS]",
|
||||
"gsheets.sheets_tools": "[SHEETS]",
|
||||
"gchat.chat_tools": "[CHAT]",
|
||||
"gforms.forms_tools": "[FORMS]",
|
||||
"gslides.slides_tools": "[SLIDES]",
|
||||
"gtasks.tasks_tools": "[TASKS]",
|
||||
"gsearch.search_tools": "[SEARCH]",
|
||||
}
|
||||
|
||||
return ascii_prefixes.get(logger_name, f'[{level_name}]')
|
||||
return ascii_prefixes.get(logger_name, f"[{level_name}]")
|
||||
|
||||
def _enhance_message(self, message: str) -> str:
|
||||
"""Enhance the log message with better formatting."""
|
||||
@@ -80,7 +81,9 @@ class EnhancedLogFormatter(logging.Formatter):
|
||||
# Tool tier loading messages
|
||||
if "resolved to" in message and "tools across" in message:
|
||||
# Extract numbers and service names for better formatting
|
||||
pattern = r"Tier '(\w+)' resolved to (\d+) tools across (\d+) services: (.+)"
|
||||
pattern = (
|
||||
r"Tier '(\w+)' resolved to (\d+) tools across (\d+) services: (.+)"
|
||||
)
|
||||
match = re.search(pattern, message)
|
||||
if match:
|
||||
tier, tool_count, service_count, services = match.groups()
|
||||
@@ -113,7 +116,9 @@ class EnhancedLogFormatter(logging.Formatter):
|
||||
return message
|
||||
|
||||
|
||||
def setup_enhanced_logging(log_level: int = logging.INFO, use_colors: bool = True) -> None:
|
||||
def setup_enhanced_logging(
|
||||
log_level: int = logging.INFO, use_colors: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Set up enhanced logging with ASCII prefix formatter for the entire application.
|
||||
|
||||
@@ -129,12 +134,19 @@ def setup_enhanced_logging(log_level: int = logging.INFO, use_colors: bool = Tru
|
||||
|
||||
# Update existing console handlers
|
||||
for handler in root_logger.handlers:
|
||||
if isinstance(handler, logging.StreamHandler) and handler.stream.name in ['<stderr>', '<stdout>']:
|
||||
if isinstance(handler, logging.StreamHandler) and handler.stream.name in [
|
||||
"<stderr>",
|
||||
"<stdout>",
|
||||
]:
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
# If no console handler exists, create one
|
||||
console_handlers = [h for h in root_logger.handlers
|
||||
if isinstance(h, logging.StreamHandler) and h.stream.name in ['<stderr>', '<stdout>']]
|
||||
console_handlers = [
|
||||
h
|
||||
for h in root_logger.handlers
|
||||
if isinstance(h, logging.StreamHandler)
|
||||
and h.stream.name in ["<stderr>", "<stdout>"]
|
||||
]
|
||||
|
||||
if not console_handlers:
|
||||
console_handler = logging.StreamHandler()
|
||||
@@ -157,7 +169,9 @@ def configure_file_logging(logger_name: str = None) -> bool:
|
||||
bool: True if file logging was configured, False if skipped (stateless mode)
|
||||
"""
|
||||
# Check if stateless mode is enabled
|
||||
stateless_mode = os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
|
||||
stateless_mode = (
|
||||
os.getenv("WORKSPACE_MCP_STATELESS_MODE", "false").lower() == "true"
|
||||
)
|
||||
|
||||
if stateless_mode:
|
||||
logger = logging.getLogger(logger_name)
|
||||
@@ -170,14 +184,14 @@ def configure_file_logging(logger_name: str = None) -> bool:
|
||||
log_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# Go up one level since we're in core/ subdirectory
|
||||
log_file_dir = os.path.dirname(log_file_dir)
|
||||
log_file_path = os.path.join(log_file_dir, 'mcp_server_debug.log')
|
||||
log_file_path = os.path.join(log_file_dir, "mcp_server_debug.log")
|
||||
|
||||
file_handler = logging.FileHandler(log_file_path, mode='a')
|
||||
file_handler = logging.FileHandler(log_file_path, mode="a")
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
|
||||
file_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(process)d - %(threadName)s '
|
||||
'[%(module)s.%(funcName)s:%(lineno)d] - %(message)s'
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(process)d - %(threadName)s "
|
||||
"[%(module)s.%(funcName)s:%(lineno)d] - %(message)s"
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
target_logger.addHandler(file_handler)
|
||||
@@ -187,5 +201,7 @@ def configure_file_logging(logger_name: str = None) -> bool:
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
sys.stderr.write(f"CRITICAL: Failed to set up file logging to '{log_file_path}': {e}\n")
|
||||
sys.stderr.write(
|
||||
f"CRITICAL: Failed to set up file logging to '{log_file_path}': {e}\n"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -13,9 +13,13 @@ from fastmcp.server.auth.providers.google import GoogleProvider
|
||||
from auth.oauth21_session_store import get_oauth21_session_store, set_auth_provider
|
||||
from auth.google_auth import handle_auth_callback, start_auth_flow, check_client_secrets
|
||||
from auth.mcp_session_middleware import MCPSessionMiddleware
|
||||
from auth.oauth_responses import create_error_response, create_success_response, create_server_error_response
|
||||
from auth.oauth_responses import (
|
||||
create_error_response,
|
||||
create_success_response,
|
||||
create_server_error_response,
|
||||
)
|
||||
from auth.auth_info_middleware import AuthInfoMiddleware
|
||||
from auth.scopes import SCOPES, get_current_scopes # noqa
|
||||
from auth.scopes import SCOPES, get_current_scopes # noqa
|
||||
from core.config import (
|
||||
USER_GOOGLE_EMAIL,
|
||||
get_transport_mode,
|
||||
@@ -31,6 +35,7 @@ _legacy_callback_registered = False
|
||||
|
||||
session_middleware = Middleware(MCPSessionMiddleware)
|
||||
|
||||
|
||||
# Custom FastMCP that adds secure middleware stack for OAuth 2.1
|
||||
class SecureFastMCP(FastMCP):
|
||||
def streamable_http_app(self) -> "Starlette":
|
||||
@@ -46,6 +51,7 @@ class SecureFastMCP(FastMCP):
|
||||
logger.info("Added middleware stack: Session Management")
|
||||
return app
|
||||
|
||||
|
||||
server = SecureFastMCP(
|
||||
name="google_workspace",
|
||||
auth=None,
|
||||
@@ -69,6 +75,7 @@ def _ensure_legacy_callback_route() -> None:
|
||||
server.custom_route("/oauth2callback", methods=["GET"])(legacy_oauth2_callback)
|
||||
_legacy_callback_registered = True
|
||||
|
||||
|
||||
def configure_server_for_http():
|
||||
"""
|
||||
Configures the authentication provider for HTTP transport.
|
||||
@@ -83,6 +90,7 @@ def configure_server_for_http():
|
||||
|
||||
# Use centralized OAuth configuration
|
||||
from auth.oauth_config import get_oauth_config
|
||||
|
||||
config = get_oauth_config()
|
||||
|
||||
# Check if OAuth 2.1 is enabled via centralized config
|
||||
@@ -110,8 +118,12 @@ def configure_server_for_http():
|
||||
)
|
||||
# Disable protocol-level auth, expect bearer tokens in tool calls
|
||||
server.auth = None
|
||||
logger.info("OAuth 2.1 enabled with EXTERNAL provider mode - protocol-level auth disabled")
|
||||
logger.info("Expecting Authorization bearer tokens in tool call headers")
|
||||
logger.info(
|
||||
"OAuth 2.1 enabled with EXTERNAL provider mode - protocol-level auth disabled"
|
||||
)
|
||||
logger.info(
|
||||
"Expecting Authorization bearer tokens in tool call headers"
|
||||
)
|
||||
else:
|
||||
# Standard OAuth 2.1 mode: use FastMCP's GoogleProvider
|
||||
provider = GoogleProvider(
|
||||
@@ -123,13 +135,17 @@ def configure_server_for_http():
|
||||
)
|
||||
# Enable protocol-level auth
|
||||
server.auth = provider
|
||||
logger.info("OAuth 2.1 enabled using FastMCP GoogleProvider with protocol-level auth")
|
||||
logger.info(
|
||||
"OAuth 2.1 enabled using FastMCP GoogleProvider with protocol-level auth"
|
||||
)
|
||||
|
||||
# Always set auth provider for token validation in middleware
|
||||
set_auth_provider(provider)
|
||||
_auth_provider = provider
|
||||
except Exception as exc:
|
||||
logger.error("Failed to initialize FastMCP GoogleProvider: %s", exc, exc_info=True)
|
||||
logger.error(
|
||||
"Failed to initialize FastMCP GoogleProvider: %s", exc, exc_info=True
|
||||
)
|
||||
raise
|
||||
else:
|
||||
logger.info("OAuth 2.0 mode - Server will use legacy authentication.")
|
||||
@@ -143,53 +159,56 @@ def get_auth_provider() -> Optional[GoogleProvider]:
|
||||
"""Gets the global authentication provider instance."""
|
||||
return _auth_provider
|
||||
|
||||
|
||||
@server.custom_route("/health", methods=["GET"])
|
||||
async def health_check(request: Request):
|
||||
try:
|
||||
version = metadata.version("workspace-mcp")
|
||||
except metadata.PackageNotFoundError:
|
||||
version = "dev"
|
||||
return JSONResponse({
|
||||
"status": "healthy",
|
||||
"service": "workspace-mcp",
|
||||
"version": version,
|
||||
"transport": get_transport_mode()
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"status": "healthy",
|
||||
"service": "workspace-mcp",
|
||||
"version": version,
|
||||
"transport": get_transport_mode(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@server.custom_route("/attachments/{file_id}", methods=["GET"])
|
||||
async def serve_attachment(file_id: str, request: Request):
|
||||
"""Serve a stored attachment file."""
|
||||
from core.attachment_storage import get_attachment_storage
|
||||
|
||||
|
||||
storage = get_attachment_storage()
|
||||
metadata = storage.get_attachment_metadata(file_id)
|
||||
|
||||
|
||||
if not metadata:
|
||||
return JSONResponse(
|
||||
{"error": "Attachment not found or expired"},
|
||||
status_code=404
|
||||
{"error": "Attachment not found or expired"}, status_code=404
|
||||
)
|
||||
|
||||
|
||||
file_path = storage.get_attachment_path(file_id)
|
||||
if not file_path:
|
||||
return JSONResponse(
|
||||
{"error": "Attachment file not found"},
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return JSONResponse({"error": "Attachment file not found"}, status_code=404)
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=metadata["filename"],
|
||||
media_type=metadata["mime_type"]
|
||||
media_type=metadata["mime_type"],
|
||||
)
|
||||
|
||||
|
||||
async def legacy_oauth2_callback(request: Request) -> HTMLResponse:
|
||||
state = request.query_params.get("state")
|
||||
code = request.query_params.get("code")
|
||||
error = request.query_params.get("error")
|
||||
|
||||
if error:
|
||||
msg = f"Authentication failed: Google returned an error: {error}. State: {state}."
|
||||
msg = (
|
||||
f"Authentication failed: Google returned an error: {error}. State: {state}."
|
||||
)
|
||||
logger.error(msg)
|
||||
return create_error_response(msg)
|
||||
|
||||
@@ -206,17 +225,19 @@ async def legacy_oauth2_callback(request: Request) -> HTMLResponse:
|
||||
logger.info(f"OAuth callback: Received code (state: {state}).")
|
||||
|
||||
mcp_session_id = None
|
||||
if hasattr(request, 'state') and hasattr(request.state, 'session_id'):
|
||||
if hasattr(request, "state") and hasattr(request.state, "session_id"):
|
||||
mcp_session_id = request.state.session_id
|
||||
|
||||
verified_user_id, credentials = handle_auth_callback(
|
||||
scopes=get_current_scopes(),
|
||||
authorization_response=str(request.url),
|
||||
redirect_uri=get_oauth_redirect_uri_for_current_mode(),
|
||||
session_id=mcp_session_id
|
||||
session_id=mcp_session_id,
|
||||
)
|
||||
|
||||
logger.info(f"OAuth callback: Successfully authenticated user: {verified_user_id}.")
|
||||
logger.info(
|
||||
f"OAuth callback: Successfully authenticated user: {verified_user_id}."
|
||||
)
|
||||
|
||||
try:
|
||||
store = get_oauth21_session_store()
|
||||
@@ -233,7 +254,9 @@ async def legacy_oauth2_callback(request: Request) -> HTMLResponse:
|
||||
session_id=f"google-{state}",
|
||||
mcp_session_id=mcp_session_id,
|
||||
)
|
||||
logger.info(f"Stored Google credentials in OAuth 2.1 session store for {verified_user_id}")
|
||||
logger.info(
|
||||
f"Stored Google credentials in OAuth 2.1 session store for {verified_user_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store credentials in OAuth 2.1 store: {e}")
|
||||
|
||||
@@ -242,8 +265,11 @@ async def legacy_oauth2_callback(request: Request) -> HTMLResponse:
|
||||
logger.error(f"Error processing OAuth callback: {str(e)}", exc_info=True)
|
||||
return create_server_error_response(str(e))
|
||||
|
||||
|
||||
@server.tool()
|
||||
async def start_google_auth(service_name: str, user_google_email: str = USER_GOOGLE_EMAIL) -> str:
|
||||
async def start_google_auth(
|
||||
service_name: str, user_google_email: str = USER_GOOGLE_EMAIL
|
||||
) -> str:
|
||||
"""
|
||||
Manually initiate Google OAuth authentication flow.
|
||||
|
||||
@@ -268,7 +294,7 @@ async def start_google_auth(service_name: str, user_google_email: str = USER_GOO
|
||||
auth_message = await start_auth_flow(
|
||||
user_google_email=user_google_email,
|
||||
service_name=service_name,
|
||||
redirect_uri=get_oauth_redirect_uri_for_current_mode()
|
||||
redirect_uri=get_oauth_redirect_uri_for_current_mode(),
|
||||
)
|
||||
return auth_message
|
||||
except Exception as e:
|
||||
|
||||
@@ -13,32 +13,37 @@ logger = logging.getLogger(__name__)
|
||||
# Global registry of enabled tools
|
||||
_enabled_tools: Optional[Set[str]] = None
|
||||
|
||||
|
||||
def set_enabled_tools(tool_names: Optional[Set[str]]):
|
||||
"""Set the globally enabled tools."""
|
||||
global _enabled_tools
|
||||
_enabled_tools = tool_names
|
||||
|
||||
|
||||
def get_enabled_tools() -> Optional[Set[str]]:
|
||||
"""Get the set of enabled tools, or None if all tools are enabled."""
|
||||
return _enabled_tools
|
||||
|
||||
|
||||
def is_tool_enabled(tool_name: str) -> bool:
|
||||
"""Check if a specific tool is enabled."""
|
||||
if _enabled_tools is None:
|
||||
return True # All tools enabled by default
|
||||
return tool_name in _enabled_tools
|
||||
|
||||
|
||||
def conditional_tool(server, tool_name: str):
|
||||
"""
|
||||
Decorator that conditionally registers a tool based on the enabled tools set.
|
||||
|
||||
|
||||
Args:
|
||||
server: The FastMCP server instance
|
||||
tool_name: The name of the tool to register
|
||||
|
||||
|
||||
Returns:
|
||||
Either the registered tool decorator or a no-op decorator
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
if is_tool_enabled(tool_name):
|
||||
logger.debug(f"Registering tool: {tool_name}")
|
||||
@@ -46,51 +51,55 @@ def conditional_tool(server, tool_name: str):
|
||||
else:
|
||||
logger.debug(f"Skipping tool registration: {tool_name}")
|
||||
return func
|
||||
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def wrap_server_tool_method(server):
|
||||
"""
|
||||
Track tool registrations and filter them post-registration.
|
||||
"""
|
||||
original_tool = server.tool
|
||||
server._tracked_tools = []
|
||||
|
||||
|
||||
def tracking_tool(*args, **kwargs):
|
||||
original_decorator = original_tool(*args, **kwargs)
|
||||
|
||||
|
||||
def wrapper_decorator(func: Callable) -> Callable:
|
||||
tool_name = func.__name__
|
||||
server._tracked_tools.append(tool_name)
|
||||
# Always apply the original decorator to register the tool
|
||||
return original_decorator(func)
|
||||
|
||||
|
||||
return wrapper_decorator
|
||||
|
||||
|
||||
server.tool = tracking_tool
|
||||
|
||||
|
||||
def filter_server_tools(server):
|
||||
"""Remove disabled tools from the server after registration."""
|
||||
enabled_tools = get_enabled_tools()
|
||||
if enabled_tools is None:
|
||||
return
|
||||
|
||||
|
||||
tools_removed = 0
|
||||
|
||||
|
||||
# Access FastMCP's tool registry via _tool_manager._tools
|
||||
if hasattr(server, '_tool_manager'):
|
||||
if hasattr(server, "_tool_manager"):
|
||||
tool_manager = server._tool_manager
|
||||
if hasattr(tool_manager, '_tools'):
|
||||
if hasattr(tool_manager, "_tools"):
|
||||
tool_registry = tool_manager._tools
|
||||
|
||||
|
||||
tools_to_remove = []
|
||||
for tool_name in list(tool_registry.keys()):
|
||||
if not is_tool_enabled(tool_name):
|
||||
tools_to_remove.append(tool_name)
|
||||
|
||||
|
||||
for tool_name in tools_to_remove:
|
||||
del tool_registry[tool_name]
|
||||
tools_removed += 1
|
||||
|
||||
|
||||
if tools_removed > 0:
|
||||
logger.info(f"Tool tier filtering: removed {tools_removed} tools, {len(enabled_tools)} enabled")
|
||||
logger.info(
|
||||
f"Tool tier filtering: removed {tools_removed} tools, {len(enabled_tools)} enabled"
|
||||
)
|
||||
|
||||
@@ -15,33 +15,36 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
TierLevel = Literal["core", "extended", "complete"]
|
||||
|
||||
|
||||
class ToolTierLoader:
|
||||
"""Loads and manages tool tiers from configuration."""
|
||||
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
"""
|
||||
Initialize the tool tier loader.
|
||||
|
||||
|
||||
Args:
|
||||
config_path: Path to the tool_tiers.yaml file. If None, uses default location.
|
||||
"""
|
||||
if config_path is None:
|
||||
# Default to core/tool_tiers.yaml relative to this file
|
||||
config_path = Path(__file__).parent / "tool_tiers.yaml"
|
||||
|
||||
|
||||
self.config_path = Path(config_path)
|
||||
self._tiers_config: Optional[Dict] = None
|
||||
|
||||
|
||||
def _load_config(self) -> Dict:
|
||||
"""Load the tool tiers configuration from YAML file."""
|
||||
if self._tiers_config is not None:
|
||||
return self._tiers_config
|
||||
|
||||
|
||||
if not self.config_path.exists():
|
||||
raise FileNotFoundError(f"Tool tiers configuration not found: {self.config_path}")
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Tool tiers configuration not found: {self.config_path}"
|
||||
)
|
||||
|
||||
try:
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
self._tiers_config = yaml.safe_load(f)
|
||||
logger.info(f"Loaded tool tiers configuration from {self.config_path}")
|
||||
return self._tiers_config
|
||||
@@ -49,65 +52,71 @@ class ToolTierLoader:
|
||||
raise ValueError(f"Invalid YAML in tool tiers configuration: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load tool tiers configuration: {e}")
|
||||
|
||||
|
||||
def get_available_services(self) -> List[str]:
|
||||
"""Get list of all available services defined in the configuration."""
|
||||
config = self._load_config()
|
||||
return list(config.keys())
|
||||
|
||||
def get_tools_for_tier(self, tier: TierLevel, services: Optional[List[str]] = None) -> List[str]:
|
||||
|
||||
def get_tools_for_tier(
|
||||
self, tier: TierLevel, services: Optional[List[str]] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get all tools for a specific tier level.
|
||||
|
||||
|
||||
Args:
|
||||
tier: The tier level (core, extended, complete)
|
||||
services: Optional list of services to filter by. If None, includes all services.
|
||||
|
||||
|
||||
Returns:
|
||||
List of tool names for the specified tier level
|
||||
"""
|
||||
config = self._load_config()
|
||||
tools = []
|
||||
|
||||
|
||||
# If no services specified, use all available services
|
||||
if services is None:
|
||||
services = self.get_available_services()
|
||||
|
||||
|
||||
for service in services:
|
||||
if service not in config:
|
||||
logger.warning(f"Service '{service}' not found in tool tiers configuration")
|
||||
logger.warning(
|
||||
f"Service '{service}' not found in tool tiers configuration"
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
service_config = config[service]
|
||||
if tier not in service_config:
|
||||
logger.debug(f"Tier '{tier}' not defined for service '{service}'")
|
||||
continue
|
||||
|
||||
|
||||
tier_tools = service_config[tier]
|
||||
if tier_tools: # Handle empty lists
|
||||
tools.extend(tier_tools)
|
||||
|
||||
|
||||
return tools
|
||||
|
||||
def get_tools_up_to_tier(self, tier: TierLevel, services: Optional[List[str]] = None) -> List[str]:
|
||||
|
||||
def get_tools_up_to_tier(
|
||||
self, tier: TierLevel, services: Optional[List[str]] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get all tools up to and including the specified tier level.
|
||||
|
||||
|
||||
Args:
|
||||
tier: The maximum tier level to include
|
||||
services: Optional list of services to filter by. If None, includes all services.
|
||||
|
||||
|
||||
Returns:
|
||||
List of tool names up to the specified tier level
|
||||
"""
|
||||
tier_order = ["core", "extended", "complete"]
|
||||
max_tier_index = tier_order.index(tier)
|
||||
|
||||
|
||||
tools = []
|
||||
for i in range(max_tier_index + 1):
|
||||
current_tier = tier_order[i]
|
||||
tools.extend(self.get_tools_for_tier(current_tier, services))
|
||||
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_tools = []
|
||||
@@ -115,39 +124,41 @@ class ToolTierLoader:
|
||||
if tool not in seen:
|
||||
seen.add(tool)
|
||||
unique_tools.append(tool)
|
||||
|
||||
|
||||
return unique_tools
|
||||
|
||||
|
||||
def get_services_for_tools(self, tool_names: List[str]) -> Set[str]:
|
||||
"""
|
||||
Get the service names that provide the specified tools.
|
||||
|
||||
|
||||
Args:
|
||||
tool_names: List of tool names to lookup
|
||||
|
||||
|
||||
Returns:
|
||||
Set of service names that provide any of the specified tools
|
||||
"""
|
||||
config = self._load_config()
|
||||
services = set()
|
||||
|
||||
|
||||
for service, service_config in config.items():
|
||||
for tier_name, tier_tools in service_config.items():
|
||||
if tier_tools and any(tool in tier_tools for tool in tool_names):
|
||||
services.add(service)
|
||||
break
|
||||
|
||||
|
||||
return services
|
||||
|
||||
|
||||
def get_tools_for_tier(tier: TierLevel, services: Optional[List[str]] = None) -> List[str]:
|
||||
def get_tools_for_tier(
|
||||
tier: TierLevel, services: Optional[List[str]] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Convenience function to get tools for a specific tier.
|
||||
|
||||
|
||||
Args:
|
||||
tier: The tier level (core, extended, complete)
|
||||
services: Optional list of services to filter by
|
||||
|
||||
|
||||
Returns:
|
||||
List of tool names for the specified tier level
|
||||
"""
|
||||
@@ -155,27 +166,31 @@ def get_tools_for_tier(tier: TierLevel, services: Optional[List[str]] = None) ->
|
||||
return loader.get_tools_up_to_tier(tier, services)
|
||||
|
||||
|
||||
def resolve_tools_from_tier(tier: TierLevel, services: Optional[List[str]] = None) -> tuple[List[str], List[str]]:
|
||||
def resolve_tools_from_tier(
|
||||
tier: TierLevel, services: Optional[List[str]] = None
|
||||
) -> tuple[List[str], List[str]]:
|
||||
"""
|
||||
Resolve tool names and service names for the specified tier.
|
||||
|
||||
|
||||
Args:
|
||||
tier: The tier level (core, extended, complete)
|
||||
services: Optional list of services to filter by
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (tool_names, service_names) where:
|
||||
- tool_names: List of specific tool names for the tier
|
||||
- service_names: List of service names that should be imported
|
||||
"""
|
||||
loader = ToolTierLoader()
|
||||
|
||||
|
||||
# Get all tools for the tier
|
||||
tools = loader.get_tools_up_to_tier(tier, services)
|
||||
|
||||
|
||||
# Map back to service names
|
||||
service_names = loader.get_services_for_tools(tools)
|
||||
|
||||
logger.info(f"Tier '{tier}' resolved to {len(tools)} tools across {len(service_names)} services: {sorted(service_names)}")
|
||||
|
||||
return tools, sorted(service_names)
|
||||
|
||||
logger.info(
|
||||
f"Tier '{tier}' resolved to {len(tools)} tools across {len(service_names)} services: {sorted(service_names)}"
|
||||
)
|
||||
|
||||
return tools, sorted(service_names)
|
||||
|
||||
@@ -27,6 +27,7 @@ class UserInputError(Exception):
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def check_credentials_directory_permissions(credentials_dir: str = None) -> None:
|
||||
"""
|
||||
Check if the service has appropriate permissions to create and write to the .credentials directory.
|
||||
@@ -181,7 +182,7 @@ def extract_office_xml_text(file_bytes: bytes, mime_type: str) -> Optional[str]:
|
||||
member_texts.append(shared_strings[ss_idx])
|
||||
else:
|
||||
logger.warning(
|
||||
f"Invalid shared string index {ss_idx} in {member}. Max index: {len(shared_strings)-1}"
|
||||
f"Invalid shared string index {ss_idx} in {member}. Max index: {len(shared_strings) - 1}"
|
||||
)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
@@ -240,7 +241,9 @@ def extract_office_xml_text(file_bytes: bytes, mime_type: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def handle_http_errors(tool_name: str, is_read_only: bool = False, service_type: Optional[str] = None):
|
||||
def handle_http_errors(
|
||||
tool_name: str, is_read_only: bool = False, service_type: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
A decorator to handle Google API HttpErrors and transient SSL errors in a standardized way.
|
||||
|
||||
@@ -288,11 +291,16 @@ def handle_http_errors(tool_name: str, is_read_only: bool = False, service_type:
|
||||
except HttpError as error:
|
||||
user_google_email = kwargs.get("user_google_email", "N/A")
|
||||
error_details = str(error)
|
||||
|
||||
|
||||
# Check if this is an API not enabled error
|
||||
if error.resp.status == 403 and "accessNotConfigured" in error_details:
|
||||
enablement_msg = get_api_enablement_message(error_details, service_type)
|
||||
|
||||
if (
|
||||
error.resp.status == 403
|
||||
and "accessNotConfigured" in error_details
|
||||
):
|
||||
enablement_msg = get_api_enablement_message(
|
||||
error_details, service_type
|
||||
)
|
||||
|
||||
if enablement_msg:
|
||||
message = (
|
||||
f"API error in {tool_name}: {enablement_msg}\n\n"
|
||||
@@ -314,7 +322,7 @@ def handle_http_errors(tool_name: str, is_read_only: bool = False, service_type:
|
||||
else:
|
||||
# Other HTTP errors (400 Bad Request, etc.) - don't suggest re-auth
|
||||
message = f"API error in {tool_name}: {error}"
|
||||
|
||||
|
||||
logger.error(f"API error in {tool_name}: {error}", exc_info=True)
|
||||
raise Exception(message) from error
|
||||
except TransientNetworkError:
|
||||
|
||||
Reference in New Issue
Block a user