pr feedback
This commit is contained in:
@@ -11,8 +11,9 @@ import httpx
|
||||
import base64
|
||||
import ipaddress
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from tempfile import NamedTemporaryFile
|
||||
from urllib.parse import urljoin, urlparse, urlunparse
|
||||
from urllib.request import url2pathname
|
||||
@@ -44,6 +45,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DOWNLOAD_CHUNK_SIZE_BYTES = 256 * 1024 # 256 KB
|
||||
UPLOAD_CHUNK_SIZE_BYTES = 5 * 1024 * 1024 # 5 MB (Google recommended minimum)
|
||||
MAX_DOWNLOAD_BYTES = 2 * 1024 * 1024 * 1024 # 2 GB safety limit for URL downloads
|
||||
|
||||
|
||||
@server.tool()
|
||||
@@ -600,37 +602,49 @@ async def create_drive_file(
|
||||
.execute
|
||||
)
|
||||
else:
|
||||
# Use NamedTemporaryFile to stream download and upload
|
||||
# Stream download to temp file with SSRF protection, then upload
|
||||
with NamedTemporaryFile() as temp_file:
|
||||
total_bytes = 0
|
||||
# Use SSRF-safe fetch (validates each redirect target)
|
||||
resp = await _ssrf_safe_fetch(fileUrl)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch file from URL: {fileUrl} (status {resp.status_code})"
|
||||
)
|
||||
content_type = None
|
||||
|
||||
file_data = resp.content
|
||||
await asyncio.to_thread(temp_file.write, file_data)
|
||||
total_bytes = len(file_data)
|
||||
async with _ssrf_safe_stream(fileUrl) as resp:
|
||||
if resp.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to fetch file from URL: {fileUrl} "
|
||||
f"(status {resp.status_code})"
|
||||
)
|
||||
|
||||
content_type = resp.headers.get("Content-Type")
|
||||
|
||||
async for chunk in resp.aiter_bytes(
|
||||
chunk_size=DOWNLOAD_CHUNK_SIZE_BYTES
|
||||
):
|
||||
total_bytes += len(chunk)
|
||||
if total_bytes > MAX_DOWNLOAD_BYTES:
|
||||
raise Exception(
|
||||
f"Download exceeded {MAX_DOWNLOAD_BYTES} byte limit"
|
||||
)
|
||||
await asyncio.to_thread(temp_file.write, chunk)
|
||||
|
||||
logger.info(
|
||||
f"[create_drive_file] Downloaded {total_bytes} bytes from URL before upload."
|
||||
f"[create_drive_file] Downloaded {total_bytes} bytes "
|
||||
f"from URL before upload."
|
||||
)
|
||||
|
||||
# Try to get MIME type from Content-Type header
|
||||
content_type = resp.headers.get("Content-Type")
|
||||
if content_type and content_type != "application/octet-stream":
|
||||
if (
|
||||
content_type
|
||||
and content_type != "application/octet-stream"
|
||||
):
|
||||
mime_type = content_type
|
||||
file_metadata["mimeType"] = mime_type
|
||||
logger.info(
|
||||
f"[create_drive_file] Using MIME type from Content-Type header: {mime_type}"
|
||||
f"[create_drive_file] Using MIME type from "
|
||||
f"Content-Type header: {mime_type}"
|
||||
)
|
||||
|
||||
# Reset file pointer to beginning for upload
|
||||
temp_file.seek(0)
|
||||
|
||||
# Upload with chunking
|
||||
media = MediaIoBaseUpload(
|
||||
temp_file,
|
||||
mimetype=mime_type,
|
||||
@@ -894,6 +908,87 @@ async def _ssrf_safe_fetch(url: str, *, stream: bool = False) -> httpx.Response:
|
||||
raise Exception(f"Too many redirects (max {max_redirects}) fetching {url}")
|
||||
|
||||
|
||||
@asynccontextmanager
async def _ssrf_safe_stream(url: str) -> AsyncIterator[httpx.Response]:
    """
    SSRF-safe streaming fetch: validates each redirect target against private
    networks, then streams the final response body without buffering it all
    in memory.

    Redirects are followed manually (at most 10 requests in total) so that
    every hop goes through ``_validate_url_not_internal`` before a connection
    is attempted. Each request is sent to one of the validated IPs returned
    by that helper, with the original hostname carried in the ``Host`` header
    and the ``sni_hostname`` request extension.

    The yielded response and the client that produced it are always closed
    when the ``async with`` block exits, including when the body raises.

    Usage::

        async with _ssrf_safe_stream(file_url) as resp:
            async for chunk in resp.aiter_bytes(chunk_size=DOWNLOAD_CHUNK_SIZE_BYTES):
                ...

    Args:
        url: The http:// or https:// URL to fetch.

    Yields:
        The first non-redirect ``httpx.Response``, opened in streaming mode.
        The status code is NOT checked here; callers must inspect it.

    Raises:
        ValueError: If a URL in the chain has a scheme other than http/https
            or has no hostname.
        Exception: If no validated IP could be reached, a redirect carries no
            ``Location`` header, or the redirect limit is exceeded. Whatever
            ``_validate_url_not_internal`` raises is propagated unchanged.
    """
    max_redirects = 10
    current_url = url

    # Resolve redirects manually so every hop is SSRF-validated
    for _ in range(max_redirects):
        parsed = urlparse(current_url)
        if parsed.scheme not in ("http", "https"):
            raise ValueError(f"Only http:// and https:// are supported: {current_url}")
        if not parsed.hostname:
            raise ValueError(f"Invalid URL: missing hostname ({current_url})")

        resolved_ips = _validate_url_not_internal(current_url)
        host_header = _format_host_header(parsed.hostname, parsed.scheme, parsed.port)

        last_error: Optional[Exception] = None
        resp: Optional[httpx.Response] = None
        client: Optional[httpx.AsyncClient] = None
        for resolved_ip in resolved_ips:
            pinned_url = _build_pinned_url(parsed, resolved_ip)
            # trust_env=False keeps environment proxy settings from rerouting
            # the request away from the pinned, validated IP.
            client = httpx.AsyncClient(follow_redirects=False, trust_env=False)
            try:
                request = client.build_request(
                    "GET",
                    pinned_url,
                    headers={"Host": host_header},
                    extensions={"sni_hostname": parsed.hostname},
                )
                resp = await client.send(request, stream=True)
                break
            except httpx.HTTPError as exc:
                # Transport-level failure: remember it and try the next IP.
                last_error = exc
                await client.aclose()
                logger.warning(
                    f"[ssrf_safe_stream] Failed via IP {resolved_ip} for "
                    f"{parsed.hostname}: {exc}"
                )
            except BaseException:
                # Anything else (e.g. task cancellation) is not retried, but
                # the client must still be released before propagating.
                await client.aclose()
                raise

        if resp is None or client is None:
            raise Exception(
                f"Failed to fetch URL after trying {len(resolved_ips)} validated IP(s): "
                f"{current_url}"
            ) from last_error

        try:
            if resp.status_code in (301, 302, 303, 307, 308):
                location = resp.headers.get("location")
                if not location:
                    raise Exception(f"Redirect with no Location header from {current_url}")
                # Location may be relative; resolve it against the current hop.
                location = urljoin(current_url, location)
                redirect_parsed = urlparse(location)
                if redirect_parsed.scheme not in ("http", "https"):
                    raise ValueError(
                        f"Redirect to disallowed scheme: {redirect_parsed.scheme}"
                    )
                current_url = location
                continue

            # Non-redirect — yield the streaming response
            yield resp
            return
        finally:
            # Runs on redirect, normal exit and error alike. Nested so that
            # the client is closed even if closing the response fails.
            try:
                await resp.aclose()
            finally:
                await client.aclose()

    raise Exception(f"Too many redirects (max {max_redirects}) fetching {url}")
|
||||
|
||||
|
||||
def _detect_source_format(file_name: str, content: Optional[str] = None) -> str:
|
||||
"""
|
||||
Detect the source MIME type based on file extension.
|
||||
|
||||
Reference in New Issue
Block a user