Merge branch 'main' of github.com:taylorwilsdon/google_workspace_mcp into feature/create-drive-folder

This commit is contained in:
Taylor Wilsdon
2026-02-19 09:55:07 -05:00
59 changed files with 5738 additions and 1234 deletions

View File

@@ -11,10 +11,11 @@ import httpx
import base64
import ipaddress
import socket
from contextlib import asynccontextmanager
from typing import Optional, List, Dict, Any
from typing import AsyncIterator, Optional, List, Dict, Any
from tempfile import NamedTemporaryFile
from urllib.parse import urlparse
from urllib.parse import urljoin, urlparse, urlunparse
from urllib.request import url2pathname
from pathlib import Path
@@ -24,7 +25,7 @@ from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
from auth.service_decorator import require_google_service
from auth.oauth_config import is_stateless_mode
from core.attachment_storage import get_attachment_storage, get_attachment_url
from core.utils import extract_office_xml_text, handle_http_errors
from core.utils import extract_office_xml_text, handle_http_errors, validate_file_path
from core.server import server
from core.config import get_transport_mode
from gdrive.drive_helpers import (
@@ -44,6 +45,7 @@ logger = logging.getLogger(__name__)
DOWNLOAD_CHUNK_SIZE_BYTES = 256 * 1024 # 256 KB
UPLOAD_CHUNK_SIZE_BYTES = 5 * 1024 * 1024 # 5 MB (Google recommended minimum)
MAX_DOWNLOAD_BYTES = 2 * 1024 * 1024 * 1024 # 2 GB safety limit for URL downloads
@server.tool()
@@ -223,25 +225,28 @@ async def get_drive_file_download_url(
export_format: Optional[str] = None,
) -> str:
"""
Gets a download URL for a Google Drive file. The file is prepared and made available via HTTP URL.
Downloads a Google Drive file and saves it to local disk.
In stdio mode, returns the local file path for direct access.
In HTTP mode, returns a temporary download URL (valid for 1 hour).
For Google native files (Docs, Sheets, Slides), exports to a useful format:
Google Docs PDF (default) or DOCX if export_format='docx'
Google Sheets XLSX (default), PDF if export_format='pdf', or CSV if export_format='csv'
Google Slides PDF (default) or PPTX if export_format='pptx'
- Google Docs -> PDF (default) or DOCX if export_format='docx'
- Google Sheets -> XLSX (default), PDF if export_format='pdf', or CSV if export_format='csv'
- Google Slides -> PDF (default) or PPTX if export_format='pptx'
For other files, downloads the original file format.
Args:
user_google_email: The user's Google email address. Required.
file_id: The Google Drive file ID to get a download URL for.
file_id: The Google Drive file ID to download.
export_format: Optional export format for Google native files.
Options: 'pdf', 'docx', 'xlsx', 'csv', 'pptx'.
If not specified, uses sensible defaults (PDF for Docs/Slides, XLSX for Sheets).
For Sheets: supports 'csv', 'pdf', or 'xlsx' (default).
Returns:
str: Download URL and file metadata. The file is available at the URL for 1 hour.
str: File metadata with either a local file path or download URL.
"""
logger.info(
f"[get_drive_file_download_url] Invoked. File ID: '{file_id}', Export format: {export_format}"
@@ -346,41 +351,45 @@ async def get_drive_file_download_url(
)
return "\n".join(result_lines)
# Save file and generate URL
# Save file to local disk and return file path
try:
storage = get_attachment_storage()
# Encode bytes to base64 (as expected by AttachmentStorage)
base64_data = base64.urlsafe_b64encode(file_content_bytes).decode("utf-8")
# Save attachment
saved_file_id = storage.save_attachment(
# Save attachment to local disk
result = storage.save_attachment(
base64_data=base64_data,
filename=output_filename,
mime_type=output_mime_type,
)
# Generate URL
download_url = get_attachment_url(saved_file_id)
result_lines = [
"File downloaded successfully!",
f"File: {file_name}",
f"File ID: {file_id}",
f"Size: {size_kb:.1f} KB ({size_bytes} bytes)",
f"MIME Type: {output_mime_type}",
f"\n📎 Download URL: {download_url}",
"\nThe file has been saved and is available at the URL above.",
"The file will expire after 1 hour.",
]
if get_transport_mode() == "stdio":
result_lines.append(f"\n📎 Saved to: {result.path}")
result_lines.append(
"\nThe file has been saved to disk and can be accessed directly via the file path."
)
else:
download_url = get_attachment_url(result.file_id)
result_lines.append(f"\n📎 Download URL: {download_url}")
result_lines.append("\nThe file will expire after 1 hour.")
if export_mime_type:
result_lines.append(
f"\nNote: Google native file exported to {output_mime_type} format."
)
logger.info(
f"[get_drive_file_download_url] Successfully saved {size_kb:.1f} KB file as {saved_file_id}"
f"[get_drive_file_download_url] Successfully saved {size_kb:.1f} KB file to {result.path}"
)
return "\n".join(result_lines)
@@ -604,8 +613,8 @@ async def create_drive_file(
raw_path = f"//{netloc}{raw_path}"
file_path = url2pathname(raw_path)
# Verify file exists
path_obj = Path(file_path)
# Validate path safety and verify file exists
path_obj = validate_file_path(file_path)
if not path_obj.exists():
extra = (
" The server is running via streamable-http, so file:// URLs must point to files inside the container or remote host."
@@ -650,21 +659,20 @@ async def create_drive_file(
elif parsed_url.scheme in ("http", "https"):
# when running in stateless mode, deployment may not have access to local file system
if is_stateless_mode():
async with httpx.AsyncClient(follow_redirects=True) as client:
resp = await client.get(fileUrl)
if resp.status_code != 200:
raise Exception(
f"Failed to fetch file from URL: {fileUrl} (status {resp.status_code})"
)
file_data = await resp.aread()
# Try to get MIME type from Content-Type header
content_type = resp.headers.get("Content-Type")
if content_type and content_type != "application/octet-stream":
mime_type = content_type
file_metadata["mimeType"] = content_type
logger.info(
f"[create_drive_file] Using MIME type from Content-Type header: {content_type}"
)
resp = await _ssrf_safe_fetch(fileUrl)
if resp.status_code != 200:
raise Exception(
f"Failed to fetch file from URL: {fileUrl} (status {resp.status_code})"
)
file_data = resp.content
# Try to get MIME type from Content-Type header
content_type = resp.headers.get("Content-Type")
if content_type and content_type != "application/octet-stream":
mime_type = content_type
file_metadata["mimeType"] = content_type
logger.info(
f"[create_drive_file] Using MIME type from Content-Type header: {content_type}"
)
media = MediaIoBaseUpload(
io.BytesIO(file_data),
@@ -684,44 +692,46 @@ async def create_drive_file(
.execute
)
else:
# Use NamedTemporaryFile to stream download and upload
# Stream download to temp file with SSRF protection, then upload
with NamedTemporaryFile() as temp_file:
total_bytes = 0
# follow redirects
async with httpx.AsyncClient(follow_redirects=True) as client:
async with client.stream("GET", fileUrl) as resp:
if resp.status_code != 200:
raise Exception(
f"Failed to fetch file from URL: {fileUrl} (status {resp.status_code})"
)
content_type = None
# Stream download in chunks
async for chunk in resp.aiter_bytes(
chunk_size=DOWNLOAD_CHUNK_SIZE_BYTES
):
await asyncio.to_thread(temp_file.write, chunk)
total_bytes += len(chunk)
logger.info(
f"[create_drive_file] Downloaded {total_bytes} bytes from URL before upload."
async with _ssrf_safe_stream(fileUrl) as resp:
if resp.status_code != 200:
raise Exception(
f"Failed to fetch file from URL: {fileUrl} "
f"(status {resp.status_code})"
)
# Try to get MIME type from Content-Type header
content_type = resp.headers.get("Content-Type")
if (
content_type
and content_type != "application/octet-stream"
):
mime_type = content_type
file_metadata["mimeType"] = mime_type
logger.info(
f"[create_drive_file] Using MIME type from Content-Type header: {mime_type}"
content_type = resp.headers.get("Content-Type")
async for chunk in resp.aiter_bytes(
chunk_size=DOWNLOAD_CHUNK_SIZE_BYTES
):
total_bytes += len(chunk)
if total_bytes > MAX_DOWNLOAD_BYTES:
raise Exception(
f"Download exceeded {MAX_DOWNLOAD_BYTES} byte limit"
)
await asyncio.to_thread(temp_file.write, chunk)
logger.info(
f"[create_drive_file] Downloaded {total_bytes} bytes "
f"from URL before upload."
)
if content_type and content_type != "application/octet-stream":
mime_type = content_type
file_metadata["mimeType"] = mime_type
logger.info(
f"[create_drive_file] Using MIME type from "
f"Content-Type header: {mime_type}"
)
# Reset file pointer to beginning for upload
temp_file.seek(0)
# Upload with chunking
media = MediaIoBaseUpload(
temp_file,
mimetype=mime_type,
@@ -788,16 +798,18 @@ GOOGLE_DOCS_IMPORT_FORMATS = {
GOOGLE_DOCS_MIME_TYPE = "application/vnd.google-apps.document"
def _validate_url_not_internal(url: str) -> None:
def _resolve_and_validate_host(hostname: str) -> list[str]:
"""
Validate that a URL doesn't point to internal/private networks (SSRF protection).
Resolve a hostname to IP addresses and validate none are private/internal.
Uses getaddrinfo to handle both IPv4 and IPv6. Fails closed on DNS errors.
Returns:
list[str]: Validated resolved IP address strings.
Raises:
ValueError: If URL points to localhost or private IP ranges
ValueError: If hostname resolves to private/internal IPs or DNS fails.
"""
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
raise ValueError("Invalid URL: no hostname")
@@ -805,15 +817,266 @@ def _validate_url_not_internal(url: str) -> None:
if hostname.lower() in ("localhost", "127.0.0.1", "::1", "0.0.0.0"):
raise ValueError("URLs pointing to localhost are not allowed")
# Resolve hostname and check if it's a private IP
# Resolve hostname using getaddrinfo (handles both IPv4 and IPv6)
try:
ip = ipaddress.ip_address(socket.gethostbyname(hostname))
if ip.is_private or ip.is_loopback or ip.is_reserved:
addr_infos = socket.getaddrinfo(hostname, None)
except socket.gaierror as e:
raise ValueError(
f"Cannot resolve hostname '{hostname}': {e}. "
"Refusing request (fail-closed)."
)
if not addr_infos:
raise ValueError(f"No addresses found for hostname: {hostname}")
resolved_ips: list[str] = []
seen_ips: set[str] = set()
for _family, _type, _proto, _canonname, sockaddr in addr_infos:
ip_str = sockaddr[0]
ip = ipaddress.ip_address(ip_str)
if not ip.is_global:
raise ValueError(
f"URLs pointing to private/internal networks are not allowed: {hostname}"
f"URLs pointing to private/internal networks are not allowed: "
f"{hostname} resolves to {ip_str}"
)
except socket.gaierror:
pass # Can't resolve, let httpx handle it
if ip_str not in seen_ips:
seen_ips.add(ip_str)
resolved_ips.append(ip_str)
return resolved_ips
def _validate_url_not_internal(url: str) -> list[str]:
"""
Validate that a URL doesn't point to internal/private networks (SSRF protection).
Returns:
list[str]: Validated resolved IP addresses for the hostname.
Raises:
ValueError: If URL points to localhost or private IP ranges.
"""
parsed = urlparse(url)
return _resolve_and_validate_host(parsed.hostname)
def _format_host_header(hostname: str, scheme: str, port: Optional[int]) -> str:
"""Format the Host header value for IPv4/IPv6 hostnames."""
host_value = hostname
if ":" in host_value and not host_value.startswith("["):
host_value = f"[{host_value}]"
is_default_port = (scheme == "http" and (port is None or port == 80)) or (
scheme == "https" and (port is None or port == 443)
)
if not is_default_port and port is not None:
host_value = f"{host_value}:{port}"
return host_value
def _build_pinned_url(parsed_url, ip_address_str: str) -> str:
"""Build a URL that targets a resolved IP while preserving path/query."""
pinned_host = ip_address_str
if ":" in pinned_host and not pinned_host.startswith("["):
pinned_host = f"[{pinned_host}]"
userinfo = ""
if parsed_url.username is not None:
userinfo = parsed_url.username
if parsed_url.password is not None:
userinfo += f":{parsed_url.password}"
userinfo += "@"
port_part = f":{parsed_url.port}" if parsed_url.port is not None else ""
netloc = f"{userinfo}{pinned_host}{port_part}"
path = parsed_url.path or "/"
return urlunparse(
(
parsed_url.scheme,
netloc,
path,
parsed_url.params,
parsed_url.query,
parsed_url.fragment,
)
)
async def _fetch_url_with_pinned_ip(url: str) -> httpx.Response:
"""
Fetch URL content by connecting to a validated, pre-resolved IP address.
This prevents DNS rebinding between validation and the outbound connection.
"""
parsed_url = urlparse(url)
if parsed_url.scheme not in ("http", "https"):
raise ValueError(f"Only http:// and https:// are supported: {url}")
if not parsed_url.hostname:
raise ValueError(f"Invalid URL: missing hostname ({url})")
resolved_ips = _validate_url_not_internal(url)
host_header = _format_host_header(
parsed_url.hostname, parsed_url.scheme, parsed_url.port
)
last_error: Optional[Exception] = None
for resolved_ip in resolved_ips:
pinned_url = _build_pinned_url(parsed_url, resolved_ip)
try:
async with httpx.AsyncClient(
follow_redirects=False, trust_env=False
) as client:
request = client.build_request(
"GET",
pinned_url,
headers={"Host": host_header},
extensions={"sni_hostname": parsed_url.hostname},
)
return await client.send(request)
except httpx.HTTPError as exc:
last_error = exc
logger.warning(
f"[ssrf_safe_fetch] Failed request via resolved IP {resolved_ip} for host "
f"{parsed_url.hostname}: {exc}"
)
raise Exception(
f"Failed to fetch URL after trying {len(resolved_ips)} validated IP(s): {url}"
) from last_error
async def _ssrf_safe_fetch(url: str, *, stream: bool = False) -> httpx.Response:
"""
Fetch a URL with SSRF protection that covers redirects and DNS rebinding.
Validates the initial URL and every redirect target against private/internal
networks. Disables automatic redirect following and handles redirects manually.
Args:
url: The URL to fetch.
stream: If True, returns a streaming response (caller must manage context).
Returns:
httpx.Response with the final response content.
Raises:
ValueError: If any URL in the redirect chain points to a private network.
Exception: If the HTTP request fails.
"""
if stream:
raise ValueError("Streaming mode is not supported by _ssrf_safe_fetch.")
max_redirects = 10
current_url = url
for _ in range(max_redirects):
resp = await _fetch_url_with_pinned_ip(current_url)
if resp.status_code in (301, 302, 303, 307, 308):
location = resp.headers.get("location")
if not location:
raise Exception(f"Redirect with no Location header from {current_url}")
# Resolve relative redirects against the current URL
location = urljoin(current_url, location)
redirect_parsed = urlparse(location)
if redirect_parsed.scheme not in ("http", "https"):
raise ValueError(
f"Redirect to disallowed scheme: {redirect_parsed.scheme}"
)
current_url = location
continue
return resp
raise Exception(f"Too many redirects (max {max_redirects}) fetching {url}")
@asynccontextmanager
async def _ssrf_safe_stream(url: str) -> AsyncIterator[httpx.Response]:
"""
SSRF-safe streaming fetch: validates each redirect target against private
networks, then streams the final response body without buffering it all
in memory.
Usage::
async with _ssrf_safe_stream(file_url) as resp:
async for chunk in resp.aiter_bytes(chunk_size=DOWNLOAD_CHUNK_SIZE_BYTES):
...
"""
max_redirects = 10
current_url = url
# Resolve redirects manually so every hop is SSRF-validated
for _ in range(max_redirects):
parsed = urlparse(current_url)
if parsed.scheme not in ("http", "https"):
raise ValueError(f"Only http:// and https:// are supported: {current_url}")
if not parsed.hostname:
raise ValueError(f"Invalid URL: missing hostname ({current_url})")
resolved_ips = _validate_url_not_internal(current_url)
host_header = _format_host_header(parsed.hostname, parsed.scheme, parsed.port)
last_error: Optional[Exception] = None
resp: Optional[httpx.Response] = None
for resolved_ip in resolved_ips:
pinned_url = _build_pinned_url(parsed, resolved_ip)
client = httpx.AsyncClient(follow_redirects=False, trust_env=False)
try:
request = client.build_request(
"GET",
pinned_url,
headers={"Host": host_header},
extensions={"sni_hostname": parsed.hostname},
)
resp = await client.send(request, stream=True)
break
except httpx.HTTPError as exc:
last_error = exc
await client.aclose()
logger.warning(
f"[ssrf_safe_stream] Failed via IP {resolved_ip} for "
f"{parsed.hostname}: {exc}"
)
except Exception:
await client.aclose()
raise
if resp is None:
raise Exception(
f"Failed to fetch URL after trying {len(resolved_ips)} validated IP(s): "
f"{current_url}"
) from last_error
if resp.status_code in (301, 302, 303, 307, 308):
location = resp.headers.get("location")
await resp.aclose()
await client.aclose()
if not location:
raise Exception(f"Redirect with no Location header from {current_url}")
location = urljoin(current_url, location)
redirect_parsed = urlparse(location)
if redirect_parsed.scheme not in ("http", "https"):
raise ValueError(
f"Redirect to disallowed scheme: {redirect_parsed.scheme}"
)
current_url = location
continue
# Non-redirect — yield the streaming response
try:
yield resp
finally:
await resp.aclose()
await client.aclose()
return
raise Exception(f"Too many redirects (max {max_redirects}) fetching {url}")
def _detect_source_format(file_name: str, content: Optional[str] = None) -> str:
@@ -945,7 +1208,7 @@ async def import_to_google_doc(
f"file_path should be a local path or file:// URL, got: {file_path}"
)
path_obj = Path(actual_path)
path_obj = validate_file_path(actual_path)
if not path_obj.exists():
raise FileNotFoundError(f"File not found: {actual_path}")
if not path_obj.is_file():
@@ -967,16 +1230,13 @@ async def import_to_google_doc(
if parsed_url.scheme not in ("http", "https"):
raise ValueError(f"file_url must be http:// or https://, got: {file_url}")
# SSRF protection: block internal/private network URLs
_validate_url_not_internal(file_url)
async with httpx.AsyncClient(follow_redirects=True) as client:
resp = await client.get(file_url)
if resp.status_code != 200:
raise Exception(
f"Failed to fetch file from URL: {file_url} (status {resp.status_code})"
)
file_data = resp.content
# SSRF protection: block internal/private network URLs and validate redirects
resp = await _ssrf_safe_fetch(file_url)
if resp.status_code != 200:
raise Exception(
f"Failed to fetch file from URL: {file_url} (status {resp.status_code})"
)
file_data = resp.content
logger.info(
f"[import_to_google_doc] Downloaded from URL: {len(file_data)} bytes"
@@ -1985,3 +2245,174 @@ async def transfer_drive_ownership(
output_parts.extend(["", "Note: Previous owner now has editor access."])
return "\n".join(output_parts)
@server.tool()
@handle_http_errors(
"set_drive_file_permissions", is_read_only=False, service_type="drive"
)
@require_google_service("drive", "drive_file")
async def set_drive_file_permissions(
service,
user_google_email: str,
file_id: str,
link_sharing: Optional[str] = None,
writers_can_share: Optional[bool] = None,
copy_requires_writer_permission: Optional[bool] = None,
) -> str:
"""
Sets file-level sharing settings and controls link sharing for a Google Drive file or folder.
This is a high-level tool for the most common permission changes. Use this to toggle
"anyone with the link" access or configure file-level sharing behavior. For managing
individual user/group permissions, use share_drive_file or update_drive_permission instead.
Args:
user_google_email (str): The user's Google email address. Required.
file_id (str): The ID of the file or folder. Required.
link_sharing (Optional[str]): Control "anyone with the link" access for the file.
- "off": Disable "anyone with the link" access for this file.
- "reader": Anyone with the link can view.
- "commenter": Anyone with the link can comment.
- "writer": Anyone with the link can edit.
writers_can_share (Optional[bool]): Whether editors can change permissions and share.
If False, only the owner can share. Defaults to None (no change).
copy_requires_writer_permission (Optional[bool]): Whether viewers and commenters
are prevented from copying, printing, or downloading. Defaults to None (no change).
Returns:
str: Summary of all permission changes applied to the file.
"""
logger.info(
f"[set_drive_file_permissions] Invoked. Email: '{user_google_email}', "
f"File ID: '{file_id}', Link sharing: '{link_sharing}', "
f"Writers can share: {writers_can_share}, Copy restriction: {copy_requires_writer_permission}"
)
if (
link_sharing is None
and writers_can_share is None
and copy_requires_writer_permission is None
):
raise ValueError(
"Must provide at least one of: link_sharing, writers_can_share, copy_requires_writer_permission"
)
valid_link_sharing = {"off", "reader", "commenter", "writer"}
if link_sharing is not None and link_sharing not in valid_link_sharing:
raise ValueError(
f"Invalid link_sharing '{link_sharing}'. Must be one of: {', '.join(sorted(valid_link_sharing))}"
)
resolved_file_id, file_metadata = await resolve_drive_item(
service, file_id, extra_fields="name, webViewLink"
)
file_id = resolved_file_id
file_name = file_metadata.get("name", "Unknown")
output_parts = [f"Permission settings updated for '{file_name}'", ""]
changes_made = []
# Handle file-level settings via files().update()
file_update_body = {}
if writers_can_share is not None:
file_update_body["writersCanShare"] = writers_can_share
if copy_requires_writer_permission is not None:
file_update_body["copyRequiresWriterPermission"] = (
copy_requires_writer_permission
)
if file_update_body:
await asyncio.to_thread(
service.files()
.update(
fileId=file_id,
body=file_update_body,
supportsAllDrives=True,
fields="id",
)
.execute
)
if writers_can_share is not None:
state = "allowed" if writers_can_share else "restricted to owner"
changes_made.append(f" - Editors sharing: {state}")
if copy_requires_writer_permission is not None:
state = "restricted" if copy_requires_writer_permission else "allowed"
changes_made.append(f" - Viewers copy/print/download: {state}")
# Handle link sharing via permissions API
if link_sharing is not None:
current_permissions = await asyncio.to_thread(
service.permissions()
.list(
fileId=file_id,
supportsAllDrives=True,
fields="permissions(id, type, role)",
)
.execute
)
anyone_perms = [
p
for p in current_permissions.get("permissions", [])
if p.get("type") == "anyone"
]
if link_sharing == "off":
if anyone_perms:
for perm in anyone_perms:
await asyncio.to_thread(
service.permissions()
.delete(
fileId=file_id,
permissionId=perm["id"],
supportsAllDrives=True,
)
.execute
)
changes_made.append(
" - Link sharing: disabled (restricted to specific people)"
)
else:
changes_made.append(" - Link sharing: already off (no change)")
else:
if anyone_perms:
await asyncio.to_thread(
service.permissions()
.update(
fileId=file_id,
permissionId=anyone_perms[0]["id"],
body={
"role": link_sharing,
"allowFileDiscovery": False,
},
supportsAllDrives=True,
fields="id, type, role",
)
.execute
)
changes_made.append(f" - Link sharing: updated to '{link_sharing}'")
else:
await asyncio.to_thread(
service.permissions()
.create(
fileId=file_id,
body={
"type": "anyone",
"role": link_sharing,
"allowFileDiscovery": False,
},
supportsAllDrives=True,
fields="id, type, role",
)
.execute
)
changes_made.append(f" - Link sharing: enabled as '{link_sharing}'")
output_parts.append("Changes:")
if changes_made:
output_parts.extend(changes_made)
else:
output_parts.append(" - No changes (already configured)")
output_parts.extend(["", f"View link: {file_metadata.get('webViewLink', 'N/A')}"])
return "\n".join(output_parts)