Merge branch 'main' of github.com:taylorwilsdon/google_workspace_mcp into feature/create-drive-folder
This commit is contained in:
@@ -11,10 +11,11 @@ import httpx
|
||||
import base64
|
||||
import ipaddress
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from tempfile import NamedTemporaryFile
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urljoin, urlparse, urlunparse
|
||||
from urllib.request import url2pathname
|
||||
from pathlib import Path
|
||||
|
||||
@@ -24,7 +25,7 @@ from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
|
||||
from auth.service_decorator import require_google_service
|
||||
from auth.oauth_config import is_stateless_mode
|
||||
from core.attachment_storage import get_attachment_storage, get_attachment_url
|
||||
from core.utils import extract_office_xml_text, handle_http_errors
|
||||
from core.utils import extract_office_xml_text, handle_http_errors, validate_file_path
|
||||
from core.server import server
|
||||
from core.config import get_transport_mode
|
||||
from gdrive.drive_helpers import (
|
||||
@@ -44,6 +45,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DOWNLOAD_CHUNK_SIZE_BYTES = 256 * 1024 # 256 KB
|
||||
UPLOAD_CHUNK_SIZE_BYTES = 5 * 1024 * 1024 # 5 MB (Google recommended minimum)
|
||||
MAX_DOWNLOAD_BYTES = 2 * 1024 * 1024 * 1024 # 2 GB safety limit for URL downloads
|
||||
|
||||
|
||||
@server.tool()
|
||||
@@ -223,25 +225,28 @@ async def get_drive_file_download_url(
|
||||
export_format: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Gets a download URL for a Google Drive file. The file is prepared and made available via HTTP URL.
|
||||
Downloads a Google Drive file and saves it to local disk.
|
||||
|
||||
In stdio mode, returns the local file path for direct access.
|
||||
In HTTP mode, returns a temporary download URL (valid for 1 hour).
|
||||
|
||||
For Google native files (Docs, Sheets, Slides), exports to a useful format:
|
||||
• Google Docs → PDF (default) or DOCX if export_format='docx'
|
||||
• Google Sheets → XLSX (default), PDF if export_format='pdf', or CSV if export_format='csv'
|
||||
• Google Slides → PDF (default) or PPTX if export_format='pptx'
|
||||
- Google Docs -> PDF (default) or DOCX if export_format='docx'
|
||||
- Google Sheets -> XLSX (default), PDF if export_format='pdf', or CSV if export_format='csv'
|
||||
- Google Slides -> PDF (default) or PPTX if export_format='pptx'
|
||||
|
||||
For other files, downloads the original file format.
|
||||
|
||||
Args:
|
||||
user_google_email: The user's Google email address. Required.
|
||||
file_id: The Google Drive file ID to get a download URL for.
|
||||
file_id: The Google Drive file ID to download.
|
||||
export_format: Optional export format for Google native files.
|
||||
Options: 'pdf', 'docx', 'xlsx', 'csv', 'pptx'.
|
||||
If not specified, uses sensible defaults (PDF for Docs/Slides, XLSX for Sheets).
|
||||
For Sheets: supports 'csv', 'pdf', or 'xlsx' (default).
|
||||
|
||||
Returns:
|
||||
str: Download URL and file metadata. The file is available at the URL for 1 hour.
|
||||
str: File metadata with either a local file path or download URL.
|
||||
"""
|
||||
logger.info(
|
||||
f"[get_drive_file_download_url] Invoked. File ID: '{file_id}', Export format: {export_format}"
|
||||
@@ -346,41 +351,45 @@ async def get_drive_file_download_url(
|
||||
)
|
||||
return "\n".join(result_lines)
|
||||
|
||||
# Save file and generate URL
|
||||
# Save file to local disk and return file path
|
||||
try:
|
||||
storage = get_attachment_storage()
|
||||
|
||||
# Encode bytes to base64 (as expected by AttachmentStorage)
|
||||
base64_data = base64.urlsafe_b64encode(file_content_bytes).decode("utf-8")
|
||||
|
||||
# Save attachment
|
||||
saved_file_id = storage.save_attachment(
|
||||
# Save attachment to local disk
|
||||
result = storage.save_attachment(
|
||||
base64_data=base64_data,
|
||||
filename=output_filename,
|
||||
mime_type=output_mime_type,
|
||||
)
|
||||
|
||||
# Generate URL
|
||||
download_url = get_attachment_url(saved_file_id)
|
||||
|
||||
result_lines = [
|
||||
"File downloaded successfully!",
|
||||
f"File: {file_name}",
|
||||
f"File ID: {file_id}",
|
||||
f"Size: {size_kb:.1f} KB ({size_bytes} bytes)",
|
||||
f"MIME Type: {output_mime_type}",
|
||||
f"\n📎 Download URL: {download_url}",
|
||||
"\nThe file has been saved and is available at the URL above.",
|
||||
"The file will expire after 1 hour.",
|
||||
]
|
||||
|
||||
if get_transport_mode() == "stdio":
|
||||
result_lines.append(f"\n📎 Saved to: {result.path}")
|
||||
result_lines.append(
|
||||
"\nThe file has been saved to disk and can be accessed directly via the file path."
|
||||
)
|
||||
else:
|
||||
download_url = get_attachment_url(result.file_id)
|
||||
result_lines.append(f"\n📎 Download URL: {download_url}")
|
||||
result_lines.append("\nThe file will expire after 1 hour.")
|
||||
|
||||
if export_mime_type:
|
||||
result_lines.append(
|
||||
f"\nNote: Google native file exported to {output_mime_type} format."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[get_drive_file_download_url] Successfully saved {size_kb:.1f} KB file as {saved_file_id}"
|
||||
f"[get_drive_file_download_url] Successfully saved {size_kb:.1f} KB file to {result.path}"
|
||||
)
|
||||
return "\n".join(result_lines)
|
||||
|
||||
@@ -604,8 +613,8 @@ async def create_drive_file(
|
||||
raw_path = f"//{netloc}{raw_path}"
|
||||
file_path = url2pathname(raw_path)
|
||||
|
||||
# Verify file exists
|
||||
path_obj = Path(file_path)
|
||||
# Validate path safety and verify file exists
|
||||
path_obj = validate_file_path(file_path)
|
||||
if not path_obj.exists():
|
||||
extra = (
|
||||
" The server is running via streamable-http, so file:// URLs must point to files inside the container or remote host."
|
||||
@@ -650,21 +659,20 @@ async def create_drive_file(
|
||||
elif parsed_url.scheme in ("http", "https"):
|
||||
# when running in stateless mode, deployment may not have access to local file system
|
||||
if is_stateless_mode():
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.get(fileUrl)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch file from URL: {fileUrl} (status {resp.status_code})"
|
||||
)
|
||||
file_data = await resp.aread()
|
||||
# Try to get MIME type from Content-Type header
|
||||
content_type = resp.headers.get("Content-Type")
|
||||
if content_type and content_type != "application/octet-stream":
|
||||
mime_type = content_type
|
||||
file_metadata["mimeType"] = content_type
|
||||
logger.info(
|
||||
f"[create_drive_file] Using MIME type from Content-Type header: {content_type}"
|
||||
)
|
||||
resp = await _ssrf_safe_fetch(fileUrl)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch file from URL: {fileUrl} (status {resp.status_code})"
|
||||
)
|
||||
file_data = resp.content
|
||||
# Try to get MIME type from Content-Type header
|
||||
content_type = resp.headers.get("Content-Type")
|
||||
if content_type and content_type != "application/octet-stream":
|
||||
mime_type = content_type
|
||||
file_metadata["mimeType"] = content_type
|
||||
logger.info(
|
||||
f"[create_drive_file] Using MIME type from Content-Type header: {content_type}"
|
||||
)
|
||||
|
||||
media = MediaIoBaseUpload(
|
||||
io.BytesIO(file_data),
|
||||
@@ -684,44 +692,46 @@ async def create_drive_file(
|
||||
.execute
|
||||
)
|
||||
else:
|
||||
# Use NamedTemporaryFile to stream download and upload
|
||||
# Stream download to temp file with SSRF protection, then upload
|
||||
with NamedTemporaryFile() as temp_file:
|
||||
total_bytes = 0
|
||||
# follow redirects
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
async with client.stream("GET", fileUrl) as resp:
|
||||
if resp.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch file from URL: {fileUrl} (status {resp.status_code})"
|
||||
)
|
||||
content_type = None
|
||||
|
||||
# Stream download in chunks
|
||||
async for chunk in resp.aiter_bytes(
|
||||
chunk_size=DOWNLOAD_CHUNK_SIZE_BYTES
|
||||
):
|
||||
await asyncio.to_thread(temp_file.write, chunk)
|
||||
total_bytes += len(chunk)
|
||||
|
||||
logger.info(
|
||||
f"[create_drive_file] Downloaded {total_bytes} bytes from URL before upload."
|
||||
async with _ssrf_safe_stream(fileUrl) as resp:
|
||||
if resp.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch file from URL: {fileUrl} "
|
||||
f"(status {resp.status_code})"
|
||||
)
|
||||
|
||||
# Try to get MIME type from Content-Type header
|
||||
content_type = resp.headers.get("Content-Type")
|
||||
if (
|
||||
content_type
|
||||
and content_type != "application/octet-stream"
|
||||
):
|
||||
mime_type = content_type
|
||||
file_metadata["mimeType"] = mime_type
|
||||
logger.info(
|
||||
f"[create_drive_file] Using MIME type from Content-Type header: {mime_type}"
|
||||
content_type = resp.headers.get("Content-Type")
|
||||
|
||||
async for chunk in resp.aiter_bytes(
|
||||
chunk_size=DOWNLOAD_CHUNK_SIZE_BYTES
|
||||
):
|
||||
total_bytes += len(chunk)
|
||||
if total_bytes > MAX_DOWNLOAD_BYTES:
|
||||
raise Exception(
|
||||
f"Download exceeded {MAX_DOWNLOAD_BYTES} byte limit"
|
||||
)
|
||||
await asyncio.to_thread(temp_file.write, chunk)
|
||||
|
||||
logger.info(
|
||||
f"[create_drive_file] Downloaded {total_bytes} bytes "
|
||||
f"from URL before upload."
|
||||
)
|
||||
|
||||
if content_type and content_type != "application/octet-stream":
|
||||
mime_type = content_type
|
||||
file_metadata["mimeType"] = mime_type
|
||||
logger.info(
|
||||
f"[create_drive_file] Using MIME type from "
|
||||
f"Content-Type header: {mime_type}"
|
||||
)
|
||||
|
||||
# Reset file pointer to beginning for upload
|
||||
temp_file.seek(0)
|
||||
|
||||
# Upload with chunking
|
||||
media = MediaIoBaseUpload(
|
||||
temp_file,
|
||||
mimetype=mime_type,
|
||||
@@ -788,16 +798,18 @@ GOOGLE_DOCS_IMPORT_FORMATS = {
|
||||
GOOGLE_DOCS_MIME_TYPE = "application/vnd.google-apps.document"
|
||||
|
||||
|
||||
def _validate_url_not_internal(url: str) -> None:
|
||||
def _resolve_and_validate_host(hostname: str) -> list[str]:
|
||||
"""
|
||||
Validate that a URL doesn't point to internal/private networks (SSRF protection).
|
||||
Resolve a hostname to IP addresses and validate none are private/internal.
|
||||
|
||||
Uses getaddrinfo to handle both IPv4 and IPv6. Fails closed on DNS errors.
|
||||
|
||||
Returns:
|
||||
list[str]: Validated resolved IP address strings.
|
||||
|
||||
Raises:
|
||||
ValueError: If URL points to localhost or private IP ranges
|
||||
ValueError: If hostname resolves to private/internal IPs or DNS fails.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
hostname = parsed.hostname
|
||||
|
||||
if not hostname:
|
||||
raise ValueError("Invalid URL: no hostname")
|
||||
|
||||
@@ -805,15 +817,266 @@ def _validate_url_not_internal(url: str) -> None:
|
||||
if hostname.lower() in ("localhost", "127.0.0.1", "::1", "0.0.0.0"):
|
||||
raise ValueError("URLs pointing to localhost are not allowed")
|
||||
|
||||
# Resolve hostname and check if it's a private IP
|
||||
# Resolve hostname using getaddrinfo (handles both IPv4 and IPv6)
|
||||
try:
|
||||
ip = ipaddress.ip_address(socket.gethostbyname(hostname))
|
||||
if ip.is_private or ip.is_loopback or ip.is_reserved:
|
||||
addr_infos = socket.getaddrinfo(hostname, None)
|
||||
except socket.gaierror as e:
|
||||
raise ValueError(
|
||||
f"Cannot resolve hostname '{hostname}': {e}. "
|
||||
"Refusing request (fail-closed)."
|
||||
)
|
||||
|
||||
if not addr_infos:
|
||||
raise ValueError(f"No addresses found for hostname: {hostname}")
|
||||
|
||||
resolved_ips: list[str] = []
|
||||
seen_ips: set[str] = set()
|
||||
for _family, _type, _proto, _canonname, sockaddr in addr_infos:
|
||||
ip_str = sockaddr[0]
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
if not ip.is_global:
|
||||
raise ValueError(
|
||||
f"URLs pointing to private/internal networks are not allowed: {hostname}"
|
||||
f"URLs pointing to private/internal networks are not allowed: "
|
||||
f"{hostname} resolves to {ip_str}"
|
||||
)
|
||||
except socket.gaierror:
|
||||
pass # Can't resolve, let httpx handle it
|
||||
if ip_str not in seen_ips:
|
||||
seen_ips.add(ip_str)
|
||||
resolved_ips.append(ip_str)
|
||||
|
||||
return resolved_ips
|
||||
|
||||
|
||||
def _validate_url_not_internal(url: str) -> list[str]:
    """
    SSRF guard: check that a URL's host does not resolve to an internal address.

    The hostname is extracted with urlparse and handed to
    _resolve_and_validate_host, which performs the actual checks.

    Args:
        url: The URL whose hostname should be validated.

    Returns:
        list[str]: Validated resolved IP addresses for the hostname.

    Raises:
        ValueError: If URL points to localhost or private IP ranges.
    """
    host = urlparse(url).hostname
    return _resolve_and_validate_host(host)
|
||||
|
||||
|
||||
def _format_host_header(hostname: str, scheme: str, port: Optional[int]) -> str:
|
||||
"""Format the Host header value for IPv4/IPv6 hostnames."""
|
||||
host_value = hostname
|
||||
if ":" in host_value and not host_value.startswith("["):
|
||||
host_value = f"[{host_value}]"
|
||||
|
||||
is_default_port = (scheme == "http" and (port is None or port == 80)) or (
|
||||
scheme == "https" and (port is None or port == 443)
|
||||
)
|
||||
if not is_default_port and port is not None:
|
||||
host_value = f"{host_value}:{port}"
|
||||
return host_value
|
||||
|
||||
|
||||
def _build_pinned_url(parsed_url, ip_address_str: str) -> str:
|
||||
"""Build a URL that targets a resolved IP while preserving path/query."""
|
||||
pinned_host = ip_address_str
|
||||
if ":" in pinned_host and not pinned_host.startswith("["):
|
||||
pinned_host = f"[{pinned_host}]"
|
||||
|
||||
userinfo = ""
|
||||
if parsed_url.username is not None:
|
||||
userinfo = parsed_url.username
|
||||
if parsed_url.password is not None:
|
||||
userinfo += f":{parsed_url.password}"
|
||||
userinfo += "@"
|
||||
|
||||
port_part = f":{parsed_url.port}" if parsed_url.port is not None else ""
|
||||
netloc = f"{userinfo}{pinned_host}{port_part}"
|
||||
|
||||
path = parsed_url.path or "/"
|
||||
return urlunparse(
|
||||
(
|
||||
parsed_url.scheme,
|
||||
netloc,
|
||||
path,
|
||||
parsed_url.params,
|
||||
parsed_url.query,
|
||||
parsed_url.fragment,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_url_with_pinned_ip(url: str) -> httpx.Response:
    """
    GET a URL by connecting to one of its validated, pre-resolved IP addresses.

    The hostname is resolved and validated once; the request is then sent to
    the IP itself with the original hostname supplied as the Host header and
    as the ``sni_hostname`` request extension. This prevents DNS rebinding
    between validation and the outbound connection. Redirects are not
    followed here.

    Args:
        url: Absolute http:// or https:// URL.

    Returns:
        httpx.Response: The response from the first validated IP that answers.

    Raises:
        ValueError: If the scheme is not http/https, the hostname is missing,
            or host validation fails.
        Exception: If the request fails on every validated IP.
    """
    target = urlparse(url)
    if target.scheme not in ("http", "https"):
        raise ValueError(f"Only http:// and https:// are supported: {url}")
    if not target.hostname:
        raise ValueError(f"Invalid URL: missing hostname ({url})")

    candidate_ips = _validate_url_not_internal(url)
    host_value = _format_host_header(target.hostname, target.scheme, target.port)

    most_recent_failure: Optional[Exception] = None
    for candidate_ip in candidate_ips:
        ip_url = _build_pinned_url(target, candidate_ip)
        try:
            async with httpx.AsyncClient(
                follow_redirects=False, trust_env=False
            ) as http_client:
                pinned_request = http_client.build_request(
                    "GET",
                    ip_url,
                    headers={"Host": host_value},
                    extensions={"sni_hostname": target.hostname},
                )
                return await http_client.send(pinned_request)
        except httpx.HTTPError as exc:
            # Remember the failure and fall through to the next address.
            most_recent_failure = exc
            logger.warning(
                f"[ssrf_safe_fetch] Failed request via resolved IP {candidate_ip} for host "
                f"{target.hostname}: {exc}"
            )

    raise Exception(
        f"Failed to fetch URL after trying {len(candidate_ips)} validated IP(s): {url}"
    ) from most_recent_failure
|
||||
|
||||
|
||||
async def _ssrf_safe_fetch(url: str, *, stream: bool = False) -> httpx.Response:
    """
    Fetch a URL with SSRF protection that covers redirects and DNS rebinding.

    Each hop is fetched through _fetch_url_with_pinned_ip, so the initial URL
    and every redirect target are validated. Redirects (301, 302, 303, 307,
    308) are followed manually, with relative Location values resolved against
    the current URL.

    Args:
        url: The URL to fetch.
        stream: Must be False. Streaming is not supported by this function and
            passing True raises ValueError.

    Returns:
        httpx.Response: The first non-redirect response.

    Raises:
        ValueError: If ``stream`` is True, a redirect targets a scheme other
            than http/https, or a URL in the chain fails validation.
        Exception: If a redirect has no Location header, the redirect limit is
            exhausted, or the HTTP request fails.
    """
    if stream:
        raise ValueError("Streaming mode is not supported by _ssrf_safe_fetch.")

    max_redirects = 10
    redirect_statuses = (301, 302, 303, 307, 308)
    next_url = url
    attempts = 0

    while attempts < max_redirects:
        attempts += 1
        response = await _fetch_url_with_pinned_ip(next_url)

        if response.status_code not in redirect_statuses:
            return response

        target = response.headers.get("location")
        if not target:
            raise Exception(f"Redirect with no Location header from {next_url}")

        # Resolve relative redirects against the current URL
        target = urljoin(next_url, target)

        target_scheme = urlparse(target).scheme
        if target_scheme not in ("http", "https"):
            raise ValueError(f"Redirect to disallowed scheme: {target_scheme}")

        next_url = target

    raise Exception(f"Too many redirects (max {max_redirects}) fetching {url}")
|
||||
|
||||
|
||||
@asynccontextmanager
async def _ssrf_safe_stream(url: str) -> AsyncIterator[httpx.Response]:
    """
    SSRF-safe streaming fetch: validates each redirect target against private
    networks, then streams the final response body without buffering it all
    in memory.

    Every hop is resolved and validated, then requested against the validated
    IP with the original hostname as Host header and ``sni_hostname``
    extension. The response and its client are closed when the context exits,
    including when the task is cancelled or closing the response fails.

    Usage::

        async with _ssrf_safe_stream(file_url) as resp:
            async for chunk in resp.aiter_bytes(chunk_size=DOWNLOAD_CHUNK_SIZE_BYTES):
                ...

    Args:
        url: Absolute http:// or https:// URL to stream.

    Yields:
        httpx.Response: The open streaming response for the first
        non-redirect hop.

    Raises:
        ValueError: If any URL in the chain has a scheme other than
            http/https, has no hostname, or fails host validation.
        Exception: If no validated IP answers, a redirect has no Location
            header, or the redirect limit is exhausted.
    """
    max_redirects = 10
    current_url = url

    # Resolve redirects manually so every hop is SSRF-validated
    for _ in range(max_redirects):
        parsed = urlparse(current_url)
        if parsed.scheme not in ("http", "https"):
            raise ValueError(f"Only http:// and https:// are supported: {current_url}")
        if not parsed.hostname:
            raise ValueError(f"Invalid URL: missing hostname ({current_url})")

        resolved_ips = _validate_url_not_internal(current_url)
        host_header = _format_host_header(parsed.hostname, parsed.scheme, parsed.port)

        last_error: Optional[Exception] = None
        resp: Optional[httpx.Response] = None
        client: Optional[httpx.AsyncClient] = None
        for resolved_ip in resolved_ips:
            pinned_url = _build_pinned_url(parsed, resolved_ip)
            client = httpx.AsyncClient(follow_redirects=False, trust_env=False)
            try:
                request = client.build_request(
                    "GET",
                    pinned_url,
                    headers={"Host": host_header},
                    extensions={"sni_hostname": parsed.hostname},
                )
                resp = await client.send(request, stream=True)
                break
            except httpx.HTTPError as exc:
                last_error = exc
                await client.aclose()
                logger.warning(
                    f"[ssrf_safe_stream] Failed via IP {resolved_ip} for "
                    f"{parsed.hostname}: {exc}"
                )
            except BaseException:
                # BaseException rather than Exception: asyncio.CancelledError
                # is not an Exception subclass, and a cancelled send must not
                # leave the client open.
                await client.aclose()
                raise

        if resp is None:
            raise Exception(
                f"Failed to fetch URL after trying {len(resolved_ips)} validated IP(s): "
                f"{current_url}"
            ) from last_error

        if resp.status_code in (301, 302, 303, 307, 308):
            location = resp.headers.get("location")
            # Nested finally: the client is closed even if closing the
            # response raises.
            try:
                await resp.aclose()
            finally:
                await client.aclose()
            if not location:
                raise Exception(f"Redirect with no Location header from {current_url}")
            location = urljoin(current_url, location)
            redirect_parsed = urlparse(location)
            if redirect_parsed.scheme not in ("http", "https"):
                raise ValueError(
                    f"Redirect to disallowed scheme: {redirect_parsed.scheme}"
                )
            current_url = location
            continue

        # Non-redirect — yield the streaming response
        try:
            yield resp
        finally:
            try:
                await resp.aclose()
            finally:
                await client.aclose()
        return

    raise Exception(f"Too many redirects (max {max_redirects}) fetching {url}")
|
||||
|
||||
|
||||
def _detect_source_format(file_name: str, content: Optional[str] = None) -> str:
|
||||
@@ -945,7 +1208,7 @@ async def import_to_google_doc(
|
||||
f"file_path should be a local path or file:// URL, got: {file_path}"
|
||||
)
|
||||
|
||||
path_obj = Path(actual_path)
|
||||
path_obj = validate_file_path(actual_path)
|
||||
if not path_obj.exists():
|
||||
raise FileNotFoundError(f"File not found: {actual_path}")
|
||||
if not path_obj.is_file():
|
||||
@@ -967,16 +1230,13 @@ async def import_to_google_doc(
|
||||
if parsed_url.scheme not in ("http", "https"):
|
||||
raise ValueError(f"file_url must be http:// or https://, got: {file_url}")
|
||||
|
||||
# SSRF protection: block internal/private network URLs
|
||||
_validate_url_not_internal(file_url)
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.get(file_url)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch file from URL: {file_url} (status {resp.status_code})"
|
||||
)
|
||||
file_data = resp.content
|
||||
# SSRF protection: block internal/private network URLs and validate redirects
|
||||
resp = await _ssrf_safe_fetch(file_url)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch file from URL: {file_url} (status {resp.status_code})"
|
||||
)
|
||||
file_data = resp.content
|
||||
|
||||
logger.info(
|
||||
f"[import_to_google_doc] Downloaded from URL: {len(file_data)} bytes"
|
||||
@@ -1985,3 +2245,174 @@ async def transfer_drive_ownership(
|
||||
output_parts.extend(["", "Note: Previous owner now has editor access."])
|
||||
|
||||
return "\n".join(output_parts)
|
||||
|
||||
|
||||
@server.tool()
@handle_http_errors(
    "set_drive_file_permissions", is_read_only=False, service_type="drive"
)
@require_google_service("drive", "drive_file")
async def set_drive_file_permissions(
    service,
    user_google_email: str,
    file_id: str,
    link_sharing: Optional[str] = None,
    writers_can_share: Optional[bool] = None,
    copy_requires_writer_permission: Optional[bool] = None,
) -> str:
    """
    Sets file-level sharing settings and controls link sharing for a Google Drive file or folder.

    This is a high-level tool for the most common permission changes. Use this to toggle
    "anyone with the link" access or configure file-level sharing behavior. For managing
    individual user/group permissions, use share_drive_file or update_drive_permission instead.

    Args:
        user_google_email (str): The user's Google email address. Required.
        file_id (str): The ID of the file or folder. Required.
        link_sharing (Optional[str]): Control "anyone with the link" access for the file.
            - "off": Disable "anyone with the link" access for this file.
            - "reader": Anyone with the link can view.
            - "commenter": Anyone with the link can comment.
            - "writer": Anyone with the link can edit.
        writers_can_share (Optional[bool]): Whether editors can change permissions and share.
            If False, only the owner can share. Defaults to None (no change).
        copy_requires_writer_permission (Optional[bool]): Whether viewers and commenters
            are prevented from copying, printing, or downloading. Defaults to None (no change).

    Returns:
        str: Summary of all permission changes applied to the file.
    """
    logger.info(
        f"[set_drive_file_permissions] Invoked. Email: '{user_google_email}', "
        f"File ID: '{file_id}', Link sharing: '{link_sharing}', "
        f"Writers can share: {writers_can_share}, Copy restriction: {copy_requires_writer_permission}"
    )

    # Reject a call that asks for nothing before touching the API.
    requested = (link_sharing, writers_can_share, copy_requires_writer_permission)
    if all(option is None for option in requested):
        raise ValueError(
            "Must provide at least one of: link_sharing, writers_can_share, copy_requires_writer_permission"
        )

    valid_link_sharing = {"off", "reader", "commenter", "writer"}
    if not (link_sharing is None or link_sharing in valid_link_sharing):
        raise ValueError(
            f"Invalid link_sharing '{link_sharing}'. Must be one of: {', '.join(sorted(valid_link_sharing))}"
        )

    file_id, file_metadata = await resolve_drive_item(
        service, file_id, extra_fields="name, webViewLink"
    )
    file_name = file_metadata.get("name", "Unknown")

    applied = []

    # File-level settings go through files().update(); only send what was asked for.
    file_settings = {
        api_field: requested_value
        for api_field, requested_value in (
            ("writersCanShare", writers_can_share),
            ("copyRequiresWriterPermission", copy_requires_writer_permission),
        )
        if requested_value is not None
    }

    if file_settings:
        update_call = service.files().update(
            fileId=file_id,
            body=file_settings,
            supportsAllDrives=True,
            fields="id",
        )
        await asyncio.to_thread(update_call.execute)
        if writers_can_share is not None:
            sharing_state = "allowed" if writers_can_share else "restricted to owner"
            applied.append(f" - Editors sharing: {sharing_state}")
        if copy_requires_writer_permission is not None:
            copy_state = "restricted" if copy_requires_writer_permission else "allowed"
            applied.append(f" - Viewers copy/print/download: {copy_state}")

    # Link sharing is expressed as a permission of type "anyone".
    if link_sharing is not None:
        list_call = service.permissions().list(
            fileId=file_id,
            supportsAllDrives=True,
            fields="permissions(id, type, role)",
        )
        listing = await asyncio.to_thread(list_call.execute)
        public_perms = [
            entry
            for entry in listing.get("permissions", [])
            if entry.get("type") == "anyone"
        ]

        if link_sharing == "off":
            # Remove every "anyone" permission; nothing to do when none exist.
            for entry in public_perms:
                delete_call = service.permissions().delete(
                    fileId=file_id,
                    permissionId=entry["id"],
                    supportsAllDrives=True,
                )
                await asyncio.to_thread(delete_call.execute)
            if public_perms:
                applied.append(
                    " - Link sharing: disabled (restricted to specific people)"
                )
            else:
                applied.append(" - Link sharing: already off (no change)")
        elif public_perms:
            # An "anyone" permission already exists: change the role on the first one.
            role_call = service.permissions().update(
                fileId=file_id,
                permissionId=public_perms[0]["id"],
                body={
                    "role": link_sharing,
                    "allowFileDiscovery": False,
                },
                supportsAllDrives=True,
                fields="id, type, role",
            )
            await asyncio.to_thread(role_call.execute)
            applied.append(f" - Link sharing: updated to '{link_sharing}'")
        else:
            create_call = service.permissions().create(
                fileId=file_id,
                body={
                    "type": "anyone",
                    "role": link_sharing,
                    "allowFileDiscovery": False,
                },
                supportsAllDrives=True,
                fields="id, type, role",
            )
            await asyncio.to_thread(create_call.execute)
            applied.append(f" - Link sharing: enabled as '{link_sharing}'")

    summary = [f"Permission settings updated for '{file_name}'", "", "Changes:"]
    summary.extend(applied if applied else [" - No changes (already configured)"])
    summary.extend(["", f"View link: {file_metadata.get('webViewLink', 'N/A')}"])

    return "\n".join(summary)
|
||||
|
||||
Reference in New Issue
Block a user