refactor authentication to dedupe

This commit is contained in:
Taylor Wilsdon
2025-05-24 10:43:55 -04:00
parent ceaa019c93
commit 9e4add5ac2
5 changed files with 311 additions and 542 deletions

View File

@@ -16,66 +16,49 @@ from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload # For file content
import io # For file content
# Use functions directly from google_auth
from auth.google_auth import get_credentials, start_auth_flow, CONFIG_CLIENT_SECRETS_PATH
from core.server import server, OAUTH_REDIRECT_URI, OAUTH_STATE_TO_SESSION_ID_MAP
from auth.google_auth import get_authenticated_google_service
from core.server import server
from core.server import (
DRIVE_READONLY_SCOPE,
DRIVE_FILE_SCOPE,
SCOPES
)
logger = logging.getLogger(__name__)
@server.tool()
async def search_drive_files(
user_google_email: str,
query: str,
user_google_email: Optional[str] = None,
page_size: int = 10,
mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id")
) -> types.CallToolResult:
"""
Searches for files and folders within a user's Google Drive based on a query string.
Prioritizes authentication via the active MCP session (`mcp_session_id`).
If the session isn't authenticated for Drive, it falls back to using `user_google_email`.
If neither provides valid credentials, it returns a message guiding the LLM to request the user's email
or initiate the authentication flow via the centralized start_auth_flow.
Args:
user_google_email (str): The user's Google email address. Required.
query (str): The search query string. Supports Google Drive search operators (e.g., 'name contains "report"', 'mimeType="application/vnd.google-apps.document"', 'parents in "folderId"').
user_google_email (Optional[str]): The user's Google email address. Required if the MCP session is not already authenticated for Drive access.
page_size (int): The maximum number of files to return. Defaults to 10.
mcp_session_id (Optional[str]): The active MCP session ID (automatically injected by FastMCP from the Mcp-Session-Id header). Used for session-based authentication.
Returns:
types.CallToolResult: Contains a list of found files/folders with their details (ID, name, type, size, modified time, link),
an error message if the API call fails,
or an authentication guidance message if credentials are required.
"""
logger.info(f"[search_drive_files] Invoked. Session: '{mcp_session_id}', Email: '{user_google_email}', Query: '{query}'")
tool_specific_scopes = [DRIVE_READONLY_SCOPE]
credentials = await asyncio.to_thread(
get_credentials,
user_google_email=user_google_email,
required_scopes=tool_specific_scopes,
client_secrets_path=CONFIG_CLIENT_SECRETS_PATH, # Use imported constant
session_id=mcp_session_id
)
tool_name = "search_drive_files"
logger.info(f"[{tool_name}] Invoked. Email: '{user_google_email}', Query: '{query}'")
if not credentials or not credentials.valid:
logger.warning(f"[search_drive_files] No valid credentials for Drive. Session: '{mcp_session_id}', Email: '{user_google_email}'.")
if user_google_email and '@' in user_google_email:
# Use the centralized start_auth_flow
return await start_auth_flow(mcp_session_id=mcp_session_id, user_google_email=user_google_email, service_name="Google Drive", redirect_uri=OAUTH_REDIRECT_URI)
else:
error_msg = "Drive Authentication required. LLM: Please ask for Google email."
return types.CallToolResult(isError=True, content=[types.TextContent(type="text", text=error_msg)])
auth_result = await get_authenticated_google_service(
service_name="drive",
version="v3",
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=[DRIVE_READONLY_SCOPE],
)
if isinstance(auth_result, types.CallToolResult):
return auth_result # Auth error
service, user_email = auth_result
try:
service = build('drive', 'v3', credentials=credentials)
user_email_from_creds = 'Unknown (Drive)'
if credentials.id_token and isinstance(credentials.id_token, dict):
user_email_from_creds = credentials.id_token.get('email', 'Unknown (Drive)')
# Check if the query looks like a structured Drive query or free text
# Basic check for operators or common keywords used in structured queries
@@ -118,49 +101,37 @@ async def search_drive_files(
@server.tool()
async def get_drive_file_content(
user_google_email: str,
file_id: str,
user_google_email: Optional[str] = None,
mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id")
) -> types.CallToolResult:
"""
Retrieves the content of a specific file from Google Drive by its ID.
Handles both native Google Docs/Sheets/Slides (exporting them to plain text or CSV) and other file types (downloading directly).
Prioritizes authentication via the active MCP session (`mcp_session_id`).
If the session isn't authenticated for Drive, it falls back to using `user_google_email`.
If neither provides valid credentials, it returns a message guiding the LLM to request the user's email
or initiate the authentication flow via the centralized start_auth_flow.
Args:
user_google_email (str): The user's Google email address. Required.
file_id (str): The unique ID of the Google Drive file to retrieve content from. This ID is typically obtained from `search_drive_files` or `list_drive_items`.
user_google_email (Optional[str]): The user's Google email address. Required if the MCP session is not already authenticated for Drive access.
mcp_session_id (Optional[str]): The active MCP session ID (automatically injected by FastMCP from the Mcp-Session-Id header). Used for session-based authentication.
Returns:
types.CallToolResult: Contains the file metadata (name, ID, type, link) and its content (decoded as UTF-8 if possible, otherwise indicates binary content),
an error message if the API call fails or the file is not accessible/found,
or an authentication guidance message if credentials are required.
"""
logger.info(f"[get_drive_file_content] Invoked. File ID: '{file_id}'")
tool_specific_scopes = [DRIVE_READONLY_SCOPE]
credentials = await asyncio.to_thread(
get_credentials,
user_google_email=user_google_email,
required_scopes=tool_specific_scopes,
client_secrets_path=CONFIG_CLIENT_SECRETS_PATH, # Use imported constant
session_id=mcp_session_id
)
tool_name = "get_drive_file_content"
logger.info(f"[{tool_name}] Invoked. File ID: '{file_id}'")
if not credentials or not credentials.valid:
logger.warning(f"[get_drive_file_content] No valid credentials for Drive. Session: '{mcp_session_id}', Email: '{user_google_email}'.")
if user_google_email and '@' in user_google_email:
# Use the centralized start_auth_flow
return await start_auth_flow(mcp_session_id=mcp_session_id, user_google_email=user_google_email, service_name="Google Drive", redirect_uri=OAUTH_REDIRECT_URI)
else:
error_msg = "Drive Authentication required. LLM: Please ask for Google email."
return types.CallToolResult(isError=True, content=[types.TextContent(type="text", text=error_msg)])
auth_result = await get_authenticated_google_service(
service_name="drive",
version="v3",
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=[DRIVE_READONLY_SCOPE],
)
if isinstance(auth_result, types.CallToolResult):
return auth_result # Auth error
service, user_email = auth_result
try:
service = build('drive', 'v3', credentials=credentials)
file_metadata = await asyncio.to_thread(
service.files().get(fileId=file_id, fields="id, name, mimeType, webViewLink").execute
)
@@ -200,55 +171,39 @@ async def get_drive_file_content(
@server.tool()
async def list_drive_items(
user_google_email: str,
folder_id: str = 'root', # Default to root folder
user_google_email: Optional[str] = None,
page_size: int = 100, # Default page size for listing
mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id")
) -> types.CallToolResult:
"""
Lists files and folders directly within a specified Google Drive folder.
Defaults to the root folder if `folder_id` is not provided. Does not recurse into subfolders.
Prioritizes authentication via the active MCP session (`mcp_session_id`).
If the session isn't authenticated for Drive, it falls back to using `user_google_email`.
If neither provides valid credentials, it returns a message guiding the LLM to request the user's email
or initiate the authentication flow via the centralized start_auth_flow.
Args:
user_google_email (str): The user's Google email address. Required.
folder_id (str): The ID of the Google Drive folder to list items from. Defaults to 'root'.
user_google_email (Optional[str]): The user's Google email address. Required if the MCP session is not already authenticated for Drive access.
page_size (int): The maximum number of items to return per page. Defaults to 100.
mcp_session_id (Optional[str]): The active MCP session ID (automatically injected by FastMCP from the Mcp-Session-Id header). Used for session-based authentication.
Returns:
types.CallToolResult: Contains a list of files/folders within the specified folder, including their details (ID, name, type, size, modified time, link),
an error message if the API call fails or the folder is not accessible/found,
or an authentication guidance message if credentials are required.
"""
logger.info(f"[list_drive_items] Invoked. Session: '{mcp_session_id}', Email: '{user_google_email}', Folder ID: '{folder_id}'")
tool_specific_scopes = [DRIVE_READONLY_SCOPE]
credentials = await asyncio.to_thread(
get_credentials,
user_google_email=user_google_email,
required_scopes=tool_specific_scopes,
client_secrets_path=CONFIG_CLIENT_SECRETS_PATH, # Use imported constant
session_id=mcp_session_id
)
tool_name = "list_drive_items"
logger.info(f"[{tool_name}] Invoked. Email: '{user_google_email}', Folder ID: '{folder_id}'")
if not credentials or not credentials.valid:
logger.warning(f"[list_drive_items] No valid credentials for Drive. Session: '{mcp_session_id}', Email: '{user_google_email}'.")
if user_google_email and '@' in user_google_email:
# Use the centralized start_auth_flow
return await start_auth_flow(mcp_session_id=mcp_session_id, user_google_email=user_google_email, service_name="Google Drive", redirect_uri=OAUTH_REDIRECT_URI)
else:
error_msg = "Drive Authentication required. LLM: Please ask for Google email."
return types.CallToolResult(isError=True, content=[types.TextContent(type="text", text=error_msg)])
auth_result = await get_authenticated_google_service(
service_name="drive",
version="v3",
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=[DRIVE_READONLY_SCOPE],
)
if isinstance(auth_result, types.CallToolResult):
return auth_result # Auth error
service, user_email = auth_result
try:
service = build('drive', 'v3', credentials=credentials)
user_email_from_creds = 'Unknown (Drive)'
if credentials.id_token and isinstance(credentials.id_token, dict):
user_email_from_creds = credentials.id_token.get('email', 'Unknown (Drive)')
results = await asyncio.to_thread(
service.files().list(
q=f"'{folder_id}' in parents and trashed=false", # List items directly in the folder
@@ -277,54 +232,40 @@ async def list_drive_items(
@server.tool()
async def create_drive_file(
user_google_email: str,
file_name: str,
content: str,
folder_id: str = 'root', # Default to root folder
user_google_email: Optional[str] = None,
mime_type: str = 'text/plain', # Default to plain text
mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id")
) -> types.CallToolResult:
"""
Creates a new file in Google Drive with the specified name, content, and optional parent folder.
Prioritizes authenticated MCP session, then `user_google_email`.
If no valid authentication is found, guides the LLM to obtain user's email or use `start_auth`.
Args:
user_google_email (str): The user's Google email address. Required.
file_name (str): The name for the new file.
content (str): The content to write to the file.
folder_id (str): The ID of the parent folder. Defaults to 'root'.
user_google_email (Optional[str]): User's Google email. Used if session isn't authenticated.
mime_type (str): The MIME type of the file. Defaults to 'text/plain'.
mcp_session_id (Optional[str]): Active MCP session ID (injected by FastMCP from Mcp-Session-Id header).
Returns:
A CallToolResult confirming creation or an error/auth guidance message.
"""
logger.info(f"[create_drive_file] Invoked. Session: '{mcp_session_id}', Email: '{user_google_email}', File Name: {file_name}, Folder ID: {folder_id}")
tool_specific_scopes = [DRIVE_FILE_SCOPE] # Use DRIVE_FILE_SCOPE for creating files
credentials = await asyncio.to_thread(
get_credentials,
user_google_email=user_google_email,
required_scopes=tool_specific_scopes,
client_secrets_path=CONFIG_CLIENT_SECRETS_PATH, # Use imported constant
session_id=mcp_session_id
)
tool_name = "create_drive_file"
logger.info(f"[{tool_name}] Invoked. Email: '{user_google_email}', File Name: {file_name}, Folder ID: {folder_id}")
if not credentials or not credentials.valid:
logger.warning(f"[create_drive_file] No valid credentials. Session: '{mcp_session_id}', Email: '{user_google_email}'.")
if user_google_email and '@' in user_google_email:
# Use the centralized start_auth_flow
return await start_auth_flow(mcp_session_id=mcp_session_id, user_google_email=user_google_email, service_name="Google Drive", redirect_uri=OAUTH_REDIRECT_URI)
else:
error_msg = "Authentication required to create file. LLM: Please ask for Google email."
return types.CallToolResult(isError=True, content=[types.TextContent(type="text", text=error_msg)])
auth_result = await get_authenticated_google_service(
service_name="drive",
version="v3",
tool_name=tool_name,
user_google_email=user_google_email,
required_scopes=[DRIVE_FILE_SCOPE],
)
if isinstance(auth_result, types.CallToolResult):
return auth_result # Auth error
service, user_email = auth_result
try:
service = build('drive', 'v3', credentials=credentials)
user_email_from_creds = 'Unknown (Drive)'
if credentials.id_token and isinstance(credentials.id_token, dict):
user_email_from_creds = credentials.id_token.get('email', 'Unknown (Drive)')
file_metadata = {
'name': file_name,
'parents': [folder_id],
@@ -341,9 +282,10 @@ async def create_drive_file(
)
link = created_file.get('webViewLink', 'No link available')
confirmation_message = f"Successfully created file '{created_file.get('name', file_name)}' (ID: {created_file.get('id', 'N/A')}) in folder '{folder_id}' for {user_email_from_creds}. Link: {link}"
confirmation_message = f"Successfully created file '{created_file.get('name', file_name)}' (ID: {created_file.get('id', 'N/A')}) in folder '{folder_id}' for {user_email}. Link: {link}"
logger.info(f"Successfully created file. Link: {link}")
return types.CallToolResult(content=[types.TextContent(type="text", text=confirmation_message)])
except HttpError as error:
logger.error(f"API error creating Drive file '{file_name}': {error}", exc_info=True)
return types.CallToolResult(isError=True, content=[types.TextContent(type="text", text=f"API error: {error}")])