pr feedback

This commit is contained in:
Taylor Wilsdon
2026-02-13 16:12:00 -05:00
parent 5280f3c634
commit dffdb7ffa7
2 changed files with 133 additions and 17 deletions

View File

@@ -11,8 +11,9 @@ import httpx
import base64
import ipaddress
import socket
from contextlib import asynccontextmanager
from typing import Optional, List, Dict, Any
from typing import AsyncIterator, Optional, List, Dict, Any
from tempfile import NamedTemporaryFile
from urllib.parse import urljoin, urlparse, urlunparse
from urllib.request import url2pathname
@@ -44,6 +45,7 @@ logger = logging.getLogger(__name__)
DOWNLOAD_CHUNK_SIZE_BYTES = 256 * 1024 # 256 KB
UPLOAD_CHUNK_SIZE_BYTES = 5 * 1024 * 1024 # 5 MB (Google recommended minimum)
MAX_DOWNLOAD_BYTES = 2 * 1024 * 1024 * 1024 # 2 GB safety limit for URL downloads
@server.tool()
@@ -600,37 +602,49 @@ async def create_drive_file(
.execute
)
else:
# Use NamedTemporaryFile to stream download and upload
# Stream download to temp file with SSRF protection, then upload
with NamedTemporaryFile() as temp_file:
total_bytes = 0
# Use SSRF-safe fetch (validates each redirect target)
resp = await _ssrf_safe_fetch(fileUrl)
if resp.status_code != 200:
raise Exception(
f"Failed to fetch file from URL: {fileUrl} (status {resp.status_code})"
)
content_type = None
file_data = resp.content
await asyncio.to_thread(temp_file.write, file_data)
total_bytes = len(file_data)
async with _ssrf_safe_stream(fileUrl) as resp:
if resp.status_code != 200:
raise Exception(
f"Failed to fetch file from URL: {fileUrl} "
f"(status {resp.status_code})"
)
content_type = resp.headers.get("Content-Type")
async for chunk in resp.aiter_bytes(
chunk_size=DOWNLOAD_CHUNK_SIZE_BYTES
):
total_bytes += len(chunk)
if total_bytes > MAX_DOWNLOAD_BYTES:
raise Exception(
f"Download exceeded {MAX_DOWNLOAD_BYTES} byte limit"
)
await asyncio.to_thread(temp_file.write, chunk)
logger.info(
f"[create_drive_file] Downloaded {total_bytes} bytes from URL before upload."
f"[create_drive_file] Downloaded {total_bytes} bytes "
f"from URL before upload."
)
# Try to get MIME type from Content-Type header
content_type = resp.headers.get("Content-Type")
if content_type and content_type != "application/octet-stream":
if (
content_type
and content_type != "application/octet-stream"
):
mime_type = content_type
file_metadata["mimeType"] = mime_type
logger.info(
f"[create_drive_file] Using MIME type from Content-Type header: {mime_type}"
f"[create_drive_file] Using MIME type from "
f"Content-Type header: {mime_type}"
)
# Reset file pointer to beginning for upload
temp_file.seek(0)
# Upload with chunking
media = MediaIoBaseUpload(
temp_file,
mimetype=mime_type,
@@ -894,6 +908,87 @@ async def _ssrf_safe_fetch(url: str, *, stream: bool = False) -> httpx.Response:
raise Exception(f"Too many redirects (max {max_redirects}) fetching {url}")
@asynccontextmanager
async def _ssrf_safe_stream(url: str) -> AsyncIterator[httpx.Response]:
"""
SSRF-safe streaming fetch: validates each redirect target against private
networks, then streams the final response body without buffering it all
in memory.
Usage::
async with _ssrf_safe_stream(file_url) as resp:
async for chunk in resp.aiter_bytes(chunk_size=DOWNLOAD_CHUNK_SIZE_BYTES):
...
"""
max_redirects = 10
current_url = url
# Resolve redirects manually so every hop is SSRF-validated
for _ in range(max_redirects):
parsed = urlparse(current_url)
if parsed.scheme not in ("http", "https"):
raise ValueError(f"Only http:// and https:// are supported: {current_url}")
if not parsed.hostname:
raise ValueError(f"Invalid URL: missing hostname ({current_url})")
resolved_ips = _validate_url_not_internal(current_url)
host_header = _format_host_header(parsed.hostname, parsed.scheme, parsed.port)
last_error: Optional[Exception] = None
resp: Optional[httpx.Response] = None
for resolved_ip in resolved_ips:
pinned_url = _build_pinned_url(parsed, resolved_ip)
try:
client = httpx.AsyncClient(follow_redirects=False, trust_env=False)
request = client.build_request(
"GET",
pinned_url,
headers={"Host": host_header},
extensions={"sni_hostname": parsed.hostname},
)
resp = await client.send(request, stream=True)
break
except httpx.HTTPError as exc:
last_error = exc
await client.aclose()
logger.warning(
f"[ssrf_safe_stream] Failed via IP {resolved_ip} for "
f"{parsed.hostname}: {exc}"
)
if resp is None:
raise Exception(
f"Failed to fetch URL after trying {len(resolved_ips)} validated IP(s): "
f"{current_url}"
) from last_error
if resp.status_code in (301, 302, 303, 307, 308):
location = resp.headers.get("location")
await resp.aclose()
await client.aclose()
if not location:
raise Exception(f"Redirect with no Location header from {current_url}")
location = urljoin(current_url, location)
redirect_parsed = urlparse(location)
if redirect_parsed.scheme not in ("http", "https"):
raise ValueError(
f"Redirect to disallowed scheme: {redirect_parsed.scheme}"
)
current_url = location
continue
# Non-redirect — yield the streaming response
try:
yield resp
finally:
await resp.aclose()
await client.aclose()
return
raise Exception(f"Too many redirects (max {max_redirects}) fetching {url}")
def _detect_source_format(file_name: str, content: Optional[str] = None) -> str:
"""
Detect the source MIME type based on file extension.