refactor authentication to dedupe
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple, Dict, Any, Callable
|
||||
import os
|
||||
|
||||
@@ -436,6 +437,97 @@ def get_user_info(credentials: Credentials) -> Optional[Dict[str, Any]]:
|
||||
logger.error(f"Unexpected error fetching user info: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# --- Centralized Google Service Authentication ---
|
||||
|
||||
async def get_authenticated_google_service(
|
||||
service_name: str, # "gmail", "calendar", "drive", "docs"
|
||||
version: str, # "v1", "v3"
|
||||
tool_name: str, # For logging/debugging
|
||||
user_google_email: str, # Required - no more Optional
|
||||
required_scopes: List[str],
|
||||
) -> tuple[Any, str] | types.CallToolResult:
|
||||
"""
|
||||
Centralized Google service authentication for all MCP tools.
|
||||
Returns (service, user_email) on success or CallToolResult on failure.
|
||||
|
||||
Args:
|
||||
service_name: The Google service name ("gmail", "calendar", "drive", "docs")
|
||||
version: The API version ("v1", "v3", etc.)
|
||||
tool_name: The name of the calling tool (for logging/debugging)
|
||||
user_google_email: The user's Google email address (required)
|
||||
required_scopes: List of required OAuth scopes
|
||||
|
||||
Returns:
|
||||
tuple[service, user_email] on success, or CallToolResult on auth failure
|
||||
"""
|
||||
logger.info(
|
||||
f"[{tool_name}] Attempting to get authenticated {service_name} service. Email: '{user_google_email}'"
|
||||
)
|
||||
|
||||
# Validate email format
|
||||
if not user_google_email or "@" not in user_google_email:
|
||||
error_msg = f"Authentication required for {tool_name}. No valid 'user_google_email' provided. Please provide a valid Google email address."
|
||||
logger.info(f"[{tool_name}] {error_msg}")
|
||||
return types.CallToolResult(
|
||||
isError=True, content=[types.TextContent(type="text", text=error_msg)]
|
||||
)
|
||||
|
||||
credentials = await asyncio.to_thread(
|
||||
get_credentials,
|
||||
user_google_email=user_google_email,
|
||||
required_scopes=required_scopes,
|
||||
client_secrets_path=CONFIG_CLIENT_SECRETS_PATH,
|
||||
session_id=None, # No longer using session-based auth
|
||||
)
|
||||
|
||||
if not credentials or not credentials.valid:
|
||||
logger.warning(
|
||||
f"[{tool_name}] No valid credentials. Email: '{user_google_email}'."
|
||||
)
|
||||
logger.info(
|
||||
f"[{tool_name}] Valid email '{user_google_email}' provided, initiating auth flow."
|
||||
)
|
||||
|
||||
# Import here to avoid circular import
|
||||
from config.google_config import OAUTH_REDIRECT_URI
|
||||
|
||||
# This call will return a CallToolResult which should be propagated
|
||||
return await start_auth_flow(
|
||||
mcp_session_id=None, # No longer using session-based auth
|
||||
user_google_email=user_google_email,
|
||||
service_name=f"Google {service_name.title()}",
|
||||
redirect_uri=OAUTH_REDIRECT_URI,
|
||||
)
|
||||
|
||||
try:
|
||||
service = build(service_name, version, credentials=credentials)
|
||||
log_user_email = user_google_email
|
||||
|
||||
# Try to get email from credentials if needed for validation
|
||||
if credentials and credentials.id_token:
|
||||
try:
|
||||
import jwt
|
||||
# Decode without verification (just to get email for logging)
|
||||
decoded_token = jwt.decode(credentials.id_token, options={"verify_signature": False})
|
||||
token_email = decoded_token.get("email")
|
||||
if token_email:
|
||||
log_user_email = token_email
|
||||
logger.info(f"[{tool_name}] Token email: {token_email}")
|
||||
except Exception as e:
|
||||
logger.debug(f"[{tool_name}] Could not decode id_token: {e}")
|
||||
|
||||
logger.info(f"[{tool_name}] Successfully authenticated {service_name} service for user: {log_user_email}")
|
||||
return service, log_user_email
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"[{tool_name}] Failed to build {service_name} service: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return types.CallToolResult(
|
||||
isError=True, content=[types.TextContent(type="text", text=error_msg)]
|
||||
)
|
||||
|
||||
|
||||
# Example Usage (Illustrative - not meant to be run directly without context)
|
||||
if __name__ == '__main__':
|
||||
# This block is for demonstration/testing purposes only.
|
||||
|
||||
Reference in New Issue
Block a user