unify authentication logic

This commit is contained in:
Taylor Wilsdon
2025-05-13 12:36:53 -04:00
parent 4e13196802
commit 0bebaee051
5 changed files with 320 additions and 450 deletions

View File

@@ -4,6 +4,7 @@ import os
import json
import logging
from typing import List, Optional, Tuple, Dict, Any, Callable
import os # Ensure os is imported
from oauthlib.oauth2.rfc6749.errors import InsecureTransportError
@@ -13,18 +14,35 @@ from google.auth.transport.requests import Request
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
# Import necessary components from core.server (OAUTH_REDIRECT_URI is no longer imported here)
# from core.server import OAUTH_REDIRECT_URI
# Import shared configuration from the new config file
from config.google_config import OAUTH_STATE_TO_SESSION_ID_MAP, SCOPES
from mcp import types # Ensure mcp.types is available
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Constants
DEFAULT_CREDENTIALS_DIR = ".credentials"
DEFAULT_REDIRECT_URI = "http://localhost:8000/oauth2callback"
# DEFAULT_REDIRECT_URI was previously imported, now passed as parameter
# In-memory cache for session credentials
# Maps session_id to Credentials object
# This should be a more robust cache in a production system (e.g., Redis)
_SESSION_CREDENTIALS_CACHE: Dict[str, Credentials] = {}
# Centralized Client Secrets Path Logic
_client_secrets_env = os.getenv("GOOGLE_CLIENT_SECRETS")
if _client_secrets_env:
CONFIG_CLIENT_SECRETS_PATH = _client_secrets_env
else:
# Assumes this file is in auth/ and client_secret.json is in the root
CONFIG_CLIENT_SECRETS_PATH = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
'client_secret.json'
)
# --- Helper Functions ---
@@ -112,58 +130,90 @@ def load_client_secrets(client_secrets_path: str) -> Dict[str, Any]:
# --- Core OAuth Logic ---
def start_auth_flow(
client_secrets_path: str,
scopes: List[str],
redirect_uri: str = DEFAULT_REDIRECT_URI
) -> Tuple[str, str]:
async def start_auth_flow(
mcp_session_id: Optional[str],
user_google_email: Optional[str],
service_name: str, # e.g., "Google Calendar", "Gmail" for user messages
redirect_uri: str, # Added redirect_uri as a required parameter
) -> types.CallToolResult:
"""
Initiates the OAuth 2.0 flow and returns the authorization URL and state.
Initiates the Google OAuth flow and returns an actionable message for the user.
Args:
client_secrets_path: Path to the Google client secrets JSON file.
scopes: List of OAuth scopes required.
mcp_session_id: The active MCP session ID.
user_google_email: The user's specified Google email, if provided.
service_name: The name of the Google service requiring auth (for user messages).
redirect_uri: The URI Google will redirect to after authorization.
Returns:
A tuple containing the authorization URL and the state parameter.
A CallToolResult with isError=True containing guidance for the LLM/user.
"""
initial_email_provided = bool(user_google_email and user_google_email.strip() and user_google_email.lower() != 'default')
user_display_name = f"{service_name} for '{user_google_email}'" if initial_email_provided else service_name
logger.info(f"[start_auth_flow] Initiating auth for {user_display_name} (session: {mcp_session_id}) with global SCOPES.")
try:
# Allow HTTP for localhost in development
if 'OAUTHLIB_INSECURE_TRANSPORT' not in os.environ:
if 'OAUTHLIB_INSECURE_TRANSPORT' not in os.environ and "localhost" in redirect_uri: # Use passed redirect_uri
logger.warning("OAUTHLIB_INSECURE_TRANSPORT not set. Setting it for localhost development.")
os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = '1'
# Set up the OAuth flow
oauth_state = os.urandom(16).hex()
if mcp_session_id:
OAUTH_STATE_TO_SESSION_ID_MAP[oauth_state] = mcp_session_id
logger.info(f"[start_auth_flow] Stored mcp_session_id '{mcp_session_id}' for oauth_state '{oauth_state}'.")
flow = Flow.from_client_secrets_file(
client_secrets_path,
scopes=scopes,
redirect_uri=redirect_uri
CONFIG_CLIENT_SECRETS_PATH, # Use module constant
scopes=SCOPES, # Use global SCOPES
redirect_uri=redirect_uri, # Use passed redirect_uri
state=oauth_state
)
# Indicate that the user needs *offline* access to retrieve a refresh token.
# 'prompt': 'consent' ensures the user sees the consent screen even if
# they have previously granted permissions, which is useful for getting
# a refresh token again if needed.
authorization_url, state = flow.authorization_url(
access_type='offline',
prompt='consent'
)
logger.info(f"Generated authorization URL. State: {state}")
return authorization_url, state
auth_url, _ = flow.authorization_url(access_type='offline', prompt='consent')
logger.info(f"Auth flow started for {user_display_name}. State: {oauth_state}. Advise user to visit: {auth_url}")
message_lines = [
f"**ACTION REQUIRED: Google Authentication Needed for {user_display_name}**\n",
f"To proceed, the user must authorize this application for {service_name} access using all required permissions.",
"**LLM, please present this exact authorization URL to the user as a clickable hyperlink:**",
f"Authorization URL: {auth_url}",
f"Markdown for hyperlink: [Click here to authorize {service_name} access]({auth_url})\n",
"**LLM, after presenting the link, instruct the user as follows:**",
"1. Click the link and complete the authorization in their browser.",
]
session_info_for_llm = f" (this will link to your current session {mcp_session_id})" if mcp_session_id else ""
if not initial_email_provided:
message_lines.extend([
f"2. After successful authorization{session_info_for_llm}, the browser page will display the authenticated email address.",
" **LLM: Instruct the user to provide you with this email address.**",
"3. Once you have the email, **retry their original command, ensuring you include this `user_google_email`.**"
])
else:
message_lines.append(f"2. After successful authorization{session_info_for_llm}, **retry their original command**.")
message_lines.append(f"\nThe application will use the new credentials. If '{user_google_email}' was provided, it must match the authenticated account.")
message = "\n".join(message_lines)
return types.CallToolResult(
isError=True,
content=[types.TextContent(type="text", text=message)]
)
except FileNotFoundError as e:
error_text = f"OAuth client secrets file not found: {e}. Please ensure '{CONFIG_CLIENT_SECRETS_PATH}' is correctly configured."
logger.error(error_text, exc_info=True)
return types.CallToolResult(isError=True, content=[types.TextContent(type="text", text=error_text)])
except Exception as e:
logger.error(f"Error starting OAuth flow: {e}")
# We no longer shut down the server after completing the flow
# The persistent server will handle multiple auth flows over time
raise # Re-raise the exception for the caller to handle
error_text = f"Could not initiate authentication for {user_display_name} due to an unexpected error: {str(e)}"
logger.error(f"Failed to start the OAuth flow for {user_display_name}: {e}", exc_info=True)
return types.CallToolResult(isError=True, content=[types.TextContent(type="text", text=error_text)])
def handle_auth_callback(
client_secrets_path: str,
scopes: List[str],
authorization_response: str,
redirect_uri: str = DEFAULT_REDIRECT_URI,
redirect_uri: str, # Made redirect_uri a required parameter
credentials_base_dir: str = DEFAULT_CREDENTIALS_DIR,
session_id: Optional[str] = None
) -> Tuple[str, Credentials]:
@@ -270,13 +320,13 @@ def get_credentials(
if not credentials:
logger.info(f"[get_credentials] No credentials found for user '{user_google_email}' or session '{session_id}'.")
return None
logger.info(f"[get_credentials] Credentials found. Scopes: {credentials.scopes}, Valid: {credentials.valid}, Expired: {credentials.expired}")
if not all(scope in credentials.scopes for scope in required_scopes):
logger.warning(f"[get_credentials] Credentials lack required scopes. Need: {required_scopes}, Have: {credentials.scopes}. User: '{user_google_email}', Session: '{session_id}'")
return None # Re-authentication needed for scopes
logger.info(f"[get_credentials] Credentials have sufficient scopes. User: '{user_google_email}', Session: '{session_id}'")
if credentials.valid:
@@ -292,7 +342,7 @@ def get_credentials(
# client_config = load_client_secrets(client_secrets_path) # Not strictly needed if creds have client_id/secret
credentials.refresh(Request())
logger.info(f"[get_credentials] Credentials refreshed successfully. User: '{user_google_email}', Session: '{session_id}'")
# Save refreshed credentials
if user_google_email: # Always save to file if email is known
save_credentials_to_file(user_google_email, credentials, credentials_base_dir)
@@ -337,41 +387,4 @@ if __name__ == '__main__':
# --- Flow Initiation Example ---
# In a real app, this URL would be presented to the user.
# try:
# auth_url, state = start_auth_flow(_CLIENT_SECRETS_FILE, _SCOPES)
# print(f"Please go to this URL and authorize: {auth_url}")
# print(f"State parameter: {state}") # State needs to be stored/verified in callback
# # The application would then wait for the callback...
# except Exception as e:
# print(f"Error starting flow: {e}")
# --- Callback Handling Example ---
# This would be triggered by the redirect from Google.
# callback_url = input("Paste the full callback URL here: ")
# try:
# user_id, creds = handle_auth_callback(_CLIENT_SECRETS_FILE, _SCOPES, callback_url)
# print(f"Authentication successful for user: {user_id}")
# print(f"Credentials obtained: {creds.token[:10]}...") # Print snippet
# except Exception as e:
# print(f"Error handling callback: {e}")
# --- Credential Retrieval Example ---
# This would happen when the application needs to access a Google API.
# print(f"\nAttempting to retrieve credentials for user: {_TEST_USER_ID}")
# try:
# retrieved_creds = get_credentials(_TEST_USER_ID, _SCOPES, _CLIENT_SECRETS_FILE)
# if retrieved_creds and retrieved_creds.valid:
# print(f"Successfully retrieved valid credentials for {_TEST_USER_ID}.")
# # Example: Use credentials to get user info again
# user_data = get_user_info(retrieved_creds)
# print(f"User Info: {user_data}")
# elif retrieved_creds:
# print(f"Retrieved credentials for {_TEST_USER_ID}, but they are not valid (maybe expired and couldn't refresh?).")
# else:
# print(f"Could not retrieve valid credentials for {_TEST_USER_ID}. Re-authentication needed.")
# except Exception as e:
# print(f"Error retrieving credentials: {e}")
pass # Keep the example block commented out or remove for production
# try: