refactor authentication to dedupe

This commit is contained in:
Taylor Wilsdon
2025-05-24 10:43:55 -04:00
parent ceaa019c93
commit 9e4add5ac2
5 changed files with 311 additions and 542 deletions

View File

@@ -3,6 +3,7 @@
import os
import json
import logging
import asyncio
from typing import List, Optional, Tuple, Dict, Any, Callable
import os
@@ -436,6 +437,97 @@ def get_user_info(credentials: Credentials) -> Optional[Dict[str, Any]]:
logger.error(f"Unexpected error fetching user info: {e}")
return None
# --- Centralized Google Service Authentication ---
async def get_authenticated_google_service(
service_name: str, # "gmail", "calendar", "drive", "docs"
version: str, # "v1", "v3"
tool_name: str, # For logging/debugging
user_google_email: str, # Required - no more Optional
required_scopes: List[str],
) -> tuple[Any, str] | types.CallToolResult:
"""
Centralized Google service authentication for all MCP tools.
Returns (service, user_email) on success or CallToolResult on failure.
Args:
service_name: The Google service name ("gmail", "calendar", "drive", "docs")
version: The API version ("v1", "v3", etc.)
tool_name: The name of the calling tool (for logging/debugging)
user_google_email: The user's Google email address (required)
required_scopes: List of required OAuth scopes
Returns:
tuple[service, user_email] on success, or CallToolResult on auth failure
"""
logger.info(
f"[{tool_name}] Attempting to get authenticated {service_name} service. Email: '{user_google_email}'"
)
# Validate email format
if not user_google_email or "@" not in user_google_email:
error_msg = f"Authentication required for {tool_name}. No valid 'user_google_email' provided. Please provide a valid Google email address."
logger.info(f"[{tool_name}] {error_msg}")
return types.CallToolResult(
isError=True, content=[types.TextContent(type="text", text=error_msg)]
)
credentials = await asyncio.to_thread(
get_credentials,
user_google_email=user_google_email,
required_scopes=required_scopes,
client_secrets_path=CONFIG_CLIENT_SECRETS_PATH,
session_id=None, # No longer using session-based auth
)
if not credentials or not credentials.valid:
logger.warning(
f"[{tool_name}] No valid credentials. Email: '{user_google_email}'."
)
logger.info(
f"[{tool_name}] Valid email '{user_google_email}' provided, initiating auth flow."
)
# Import here to avoid circular import
from config.google_config import OAUTH_REDIRECT_URI
# This call will return a CallToolResult which should be propagated
return await start_auth_flow(
mcp_session_id=None, # No longer using session-based auth
user_google_email=user_google_email,
service_name=f"Google {service_name.title()}",
redirect_uri=OAUTH_REDIRECT_URI,
)
try:
service = build(service_name, version, credentials=credentials)
log_user_email = user_google_email
# Try to get email from credentials if needed for validation
if credentials and credentials.id_token:
try:
import jwt
# Decode without verification (just to get email for logging)
decoded_token = jwt.decode(credentials.id_token, options={"verify_signature": False})
token_email = decoded_token.get("email")
if token_email:
log_user_email = token_email
logger.info(f"[{tool_name}] Token email: {token_email}")
except Exception as e:
logger.debug(f"[{tool_name}] Could not decode id_token: {e}")
logger.info(f"[{tool_name}] Successfully authenticated {service_name} service for user: {log_user_email}")
return service, log_user_email
except Exception as e:
error_msg = f"[{tool_name}] Failed to build {service_name} service: {str(e)}"
logger.error(error_msg, exc_info=True)
return types.CallToolResult(
isError=True, content=[types.TextContent(type="text", text=error_msg)]
)
# Example Usage (Illustrative - not meant to be run directly without context)
if __name__ == '__main__':
# This block is for demonstration/testing purposes only.