mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-18 05:49:41 +08:00
- Disable aiohttp auto-redirects and re-validate every Location target against the same allowlist used for the initial URL, closing an SSRF vector where an allowed host could redirect to an arbitrary internal endpoint.
- Accept subdomains of allowlisted hosts so Hugging Face's LFS CDN (cdn-lfs.huggingface.co et al.) keeps working under the stricter redirect handling.
- Pass an explicit ClientTimeout (connect/sock_read) so hung remotes surface as errors instead of blocking the request handler forever.
- Log the exception value alongside the traceback on the 500 fallback.
- Add positive coverage for normalize_model_relative_path, Civitai URL allowlisting, and the redirect-following / SSRF-rejection branches of open_model_download_response.

Co-authored-by: Cursor <cursoragent@cursor.com>
282 lines
9.6 KiB
Python
282 lines
9.6 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import posixpath
|
|
from dataclasses import dataclass
|
|
from pathlib import PurePosixPath
|
|
from urllib.parse import urljoin, urlparse
|
|
|
|
from aiohttp import ClientSession, ClientTimeout
|
|
|
|
import folder_paths
|
|
|
|
|
|
# Hosts that may be fetched from. is_allowed_model_download_url accepts an
# exact match or any subdomain of one of these.
ALLOWED_DOWNLOAD_HOSTS = {
    "huggingface.co",
    "civitai.com",
    "civitai.red",
}

# Filename suffixes accepted for a downloaded model. The requested name is
# lower-cased before it is compared against this tuple.
ALLOWED_DOWNLOAD_SUFFIXES = (
    ".safetensors",
    ".sft",
    ".ckpt",
    ".pth",
    ".pt",
)

# Folder keys that are rejected as a download destination even when they
# are registered in folder_paths.
BLOCKED_MODEL_FOLDERS = {
    "configs",
    "custom_nodes",
}

# Number of bytes requested per read of the response body (1 MiB).
CHUNK_SIZE = 1024 * 1024
|
|
|
|
# Timeout policy for the download request. There is deliberately no overall
# deadline (``total=None``) because large models can take a long time to
# transfer. Instead the connection phase is capped, and ``sock_read`` caps
# the wait between chunks: a slow transfer that keeps delivering data is
# allowed to finish, while a stalled socket turns into an error rather than
# blocking the request handler forever.
DOWNLOAD_TIMEOUT = ClientTimeout(
    total=None,
    connect=30,
    sock_connect=30,
    sock_read=300,
)

# Upper bound on the redirects followed by hand. Hugging Face typically
# answers ``/resolve/main/...`` with a single redirect to a CDN URL, so a
# small budget is sufficient and still stops redirect loops.
MAX_DOWNLOAD_REDIRECTS = 5
|
|
|
|
# Exact URLs that is_allowed_model_download_url accepts before any scheme or
# host check. Matching is by full string equality, query string included.
#
# NOTE(review): the github.com entry is not covered by ALLOWED_DOWNLOAD_HOSTS,
# and GitHub release downloads presumably redirect to a different host. If so,
# the redirect check in open_model_download_response would reject that hop.
# Confirm this entry is actually downloadable.
WHITE_LISTED_DOWNLOAD_URLS = {
    "https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
    "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
    "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
}
|
|
|
|
|
|
@dataclass(frozen=True)
class ModelDownloadRequest:
    """Immutable description of a model download that a client asked for.

    Instances are built by ``parse_model_download_request`` after the
    incoming payload has passed validation.
    """

    # Relative path of the model file inside its model folder.
    name: str
    # Source URL the model is fetched from.
    url: str
    # Model folder key supplied by the client.
    directory: str
|
|
|
|
|
|
@dataclass(frozen=True)
class ModelDownloadDestination:
    """Immutable record of where a requested model lives or will be written.

    Instances are built by ``resolve_model_download_destination``.
    """

    # Model folder key after legacy-name mapping.
    directory: str
    # Path of the model relative to the model folder.
    relative_path: str
    # Absolute path of the existing file, or of the file to create.
    full_path: str
    # True when the model was already found on disk and nothing is downloaded.
    already_exists: bool = False
|
|
|
|
|
|
class ModelDownloadError(Exception):
    """Error raised when a model download request cannot be honoured.

    Attributes:
        status: Status code attached to the failure. It defaults to 400;
            raisers in this module also use values such as 403, 404, 409
            and 502.
    """

    def __init__(self, message: str, status: int = 400):
        self.status = status
        super().__init__(message)
|
|
|
|
|
|
def parse_model_download_request(data) -> ModelDownloadRequest:
    """Validate a decoded request payload and turn it into a request object.

    Args:
        data: Decoded request body. It must be a dict holding string values
            under the keys ``"name"``, ``"url"`` and ``"directory"``.

    Returns:
        A ``ModelDownloadRequest`` whose ``name`` is the normalized relative
        path and whose ``url`` and ``directory`` are whitespace-stripped.

    Raises:
        ModelDownloadError: If the payload is not a dict, a field is missing,
            not a string or blank, the URL is not allowed, the name is not a
            safe relative path, or the name's extension is not allowed.
    """
    if not isinstance(data, dict):
        raise ModelDownloadError("Expected a JSON object.")

    raw_values = [data.get(key) for key in ("name", "url", "directory")]
    if not all(isinstance(value, str) for value in raw_values):
        raise ModelDownloadError("Missing model name, URL, or directory.")

    name, url, directory = (value.strip() for value in raw_values)
    if not (name and url and directory):
        raise ModelDownloadError("Model name, URL, and directory are required.")

    if not is_allowed_model_download_url(url):
        raise ModelDownloadError("Model download URL is not allowed.")

    relative_path = normalize_model_relative_path(name)
    # The suffix comparison is case-insensitive: the name is lower-cased first.
    has_allowed_suffix = relative_path.lower().endswith(ALLOWED_DOWNLOAD_SUFFIXES)
    if not has_allowed_suffix:
        raise ModelDownloadError("Model filename extension is not allowed.")

    return ModelDownloadRequest(name=relative_path, url=url, directory=directory)
|
|
|
|
|
|
def is_allowed_model_download_url(url: str) -> bool:
    """Decide whether ``url`` may be fetched on behalf of the user.

    This check is used for the URL the user submits and again for every
    redirect target, so a redirect issued by an allowed host cannot escape
    the allowlist (SSRF containment).

    A URL is accepted when it is listed verbatim in
    ``WHITE_LISTED_DOWNLOAD_URLS``, or when it uses ``https`` and its host is
    an entry of ``ALLOWED_DOWNLOAD_HOSTS`` or a subdomain of one. Subdomains
    are accepted because Hugging Face and Civitai serve the actual file
    payloads from CDN subdomains (e.g. ``cdn-lfs.huggingface.co``).

    Returns:
        True if the URL is allowed, False otherwise, including when the URL
        cannot be parsed.
    """
    if url in WHITE_LISTED_DOWNLOAD_URLS:
        return True

    try:
        pieces = urlparse(url)
    except ValueError:
        return False

    if pieces.scheme != "https":
        return False

    hostname = (pieces.hostname or "").lower()
    if not hostname:
        return False

    # The leading dot in the suffix test stops "evilhuggingface.co" from
    # matching "huggingface.co".
    return any(
        hostname == allowed or hostname.endswith("." + allowed)
        for allowed in ALLOWED_DOWNLOAD_HOSTS
    )
|
|
|
|
|
|
def normalize_model_relative_path(name: str) -> str:
    """Return ``name`` as a normalized, forward-slash relative path.

    Backslashes are treated as path separators. The result has redundant
    separators and ``.`` components collapsed.

    Raises:
        ModelDownloadError: If the name contains a NUL byte, is absolute, is
            empty, or contains a ``..`` component.
    """
    if "\x00" in name:
        raise ModelDownloadError("Model filename is invalid.")

    posix_name = name.replace("\\", "/")
    if posix_name.startswith("/"):
        raise ModelDownloadError("Model filename must be relative.")

    escape_message = "Model filename must stay inside the model folder."

    # Reject on the component list first, so that a ".." is refused outright
    # rather than being resolved away by normpath below.
    components = PurePosixPath(posix_name).parts
    forbidden = {"", ".", ".."}
    if not components or not forbidden.isdisjoint(components):
        raise ModelDownloadError(escape_message)

    collapsed = posixpath.normpath(posix_name)
    if collapsed in ("", ".") or collapsed.startswith("../"):
        raise ModelDownloadError(escape_message)

    return collapsed
|
|
|
|
|
|
def resolve_model_download_destination(request: ModelDownloadRequest) -> ModelDownloadDestination:
    """Work out where the requested model is, or where it should be saved.

    The requested folder name is first mapped through
    ``folder_paths.map_legacy``. If the model is already present, the
    existing path is reported with ``already_exists=True``. Otherwise the
    path is built under the first writable root of that folder.

    Raises:
        ModelDownloadError: With status 404 if the folder is blocked or not
            registered. Also propagated from ``find_writable_model_root``
            and ``safe_join``.
    """
    folder = folder_paths.map_legacy(request.directory)
    folder_is_usable = (
        folder not in BLOCKED_MODEL_FOLDERS
        and folder in folder_paths.folder_names_and_paths
    )
    if not folder_is_usable:
        raise ModelDownloadError("Model directory is not allowed.", status=404)

    found_path = folder_paths.get_full_path(folder, request.name)
    if found_path is None:
        writable_root = find_writable_model_root(folder)
        target_path = safe_join(writable_root, request.name)
        return ModelDownloadDestination(
            directory=folder,
            relative_path=request.name,
            full_path=target_path,
        )

    return ModelDownloadDestination(
        directory=folder,
        relative_path=request.name,
        full_path=found_path,
        already_exists=True,
    )
|
|
|
|
|
|
def find_writable_model_root(directory: str) -> str:
    """Return the absolute path of the first writable root for ``directory``.

    Roots are tried in the order ``folder_paths.get_folder_paths`` yields
    them. Each root is created if it is missing; a root that cannot be
    created or written to is skipped.

    Raises:
        ModelDownloadError: With status 403 if no root is writable.
    """
    for candidate in folder_paths.get_folder_paths(directory):
        try:
            os.makedirs(candidate, exist_ok=True)
        except OSError:
            usable = False
        else:
            usable = is_writable_directory(candidate)

        if usable:
            return os.path.abspath(candidate)

    raise ModelDownloadError("No writable model folder is configured.", status=403)
|
|
|
|
|
|
def is_writable_directory(path: str) -> bool:
    """Report whether a file can be created in, and removed from, ``path``.

    The check creates and then deletes a uniquely named probe file. A unique
    name matters: with a fixed probe name, a probe left behind by a crash or
    created by a concurrent check makes the exclusive create fail, and a
    perfectly writable directory is then reported as read-only.

    Args:
        path: Directory to test.

    Returns:
        True if the probe file could be both created and removed. False on
        any ``OSError``, including when ``path`` does not exist.
    """
    import tempfile

    try:
        fd, probe = tempfile.mkstemp(prefix=".comfy-download-write-test-", dir=path)
    except OSError:
        return False

    try:
        os.close(fd)
        os.remove(probe)
    except OSError:
        # Best-effort second attempt so the probe is not left behind; the
        # directory is still reported as unusable.
        try:
            if os.path.exists(probe):
                os.remove(probe)
        except OSError:
            pass
        return False

    return True
|
|
|
|
|
|
def safe_join(root: str, relative_path: str) -> str:
    """Join ``relative_path`` onto ``root`` and refuse results outside ``root``.

    Args:
        root: Directory the result must stay inside; it is made absolute.
        relative_path: Path to append to ``root``.

    Returns:
        The absolute, normalized joined path.

    Raises:
        ModelDownloadError: If the joined path is not located under ``root``.
    """
    root = os.path.abspath(root)
    full_path = os.path.abspath(os.path.join(root, relative_path))
    try:
        inside_root = os.path.commonpath((root, full_path)) == root
    except ValueError:
        # commonpath raises ValueError when the two paths have no common
        # base at all (e.g. different drives on Windows). That is also an
        # escape from root, so report it the same way instead of letting an
        # unrelated exception type leak to the caller.
        inside_root = False

    if not inside_root:
        raise ModelDownloadError("Model filename must stay inside the model folder.")
    return full_path
|
|
|
|
|
|
async def open_model_download_response(session: ClientSession, url: str):
    """GET ``url``, following redirects by hand with an allowlist check.

    aiohttp follows redirects on its own by default, which would let an
    allowed host bounce the request to an arbitrary internal target (SSRF).
    Automatic following is therefore switched off, and every ``Location``
    value is resolved against the current URL and checked with
    ``is_allowed_model_download_url`` before it is requested.

    Args:
        session: aiohttp session used to issue the requests.
        url: Initial URL. It is not re-validated here.

    Returns:
        The first response whose status is not 301, 302, 303, 307 or 308.
        The response is still open; the caller must release it.

    Raises:
        ModelDownloadError: Status 502 if a redirect has no ``Location`` or
            the redirect budget is exhausted; status 403 if a redirect
            target is not allowed.
    """
    redirect_statuses = (301, 302, 303, 307, 308)
    target = url
    # One initial request plus up to MAX_DOWNLOAD_REDIRECTS redirected ones.
    attempts_left = MAX_DOWNLOAD_REDIRECTS + 1

    while attempts_left > 0:
        attempts_left -= 1
        response = await session.get(
            target,
            allow_redirects=False,
            timeout=DOWNLOAD_TIMEOUT,
        )
        if response.status not in redirect_statuses:
            return response

        # Read the header before releasing the redirect response.
        location = response.headers.get("Location", "").strip()
        response.release()
        if not location:
            raise ModelDownloadError("Redirect response missing Location header.", status=502)

        # Location may be relative; resolve it against the URL just fetched.
        candidate = urljoin(target, location)
        if not is_allowed_model_download_url(candidate):
            raise ModelDownloadError("Model download redirect target is not allowed.", status=403)
        target = candidate

    raise ModelDownloadError("Too many redirects while downloading model.", status=502)
|
|
|
|
|
|
async def download_model_to_destination(
    session: ClientSession,
    request: ModelDownloadRequest,
    destination: ModelDownloadDestination,
) -> dict:
    """Stream the requested model to disk and return a result payload.

    The body is written to ``<full_path>.part`` and moved onto the final
    path only after the transfer completed. The ``.part`` file is created
    exclusively, so it also acts as an "in progress" marker.

    Args:
        session: aiohttp session used for the download.
        request: Validated download request.
        destination: Resolved destination. If it is marked
            ``already_exists``, nothing is downloaded.

    Returns:
        The dict built by ``download_response``, with status
        ``"already_exists"`` or ``"downloaded"``.

    Raises:
        ModelDownloadError: Status 409 if a ``.part`` file already exists;
            status 502 if the remote answers with HTTP >= 400 or the body
            length differs from the announced ``Content-Length``. Also
            propagated from ``open_model_download_response``.
    """
    if destination.already_exists:
        return download_response("already_exists", request, destination)

    os.makedirs(os.path.dirname(destination.full_path), exist_ok=True)
    partial_path = f"{destination.full_path}.part"

    try:
        # O_EXCL makes creation fail if another download already owns the file.
        fd = os.open(partial_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
    except FileExistsError as err:
        raise ModelDownloadError("A download for this model is already in progress.", status=409) from err

    bytes_written = 0
    try:
        with os.fdopen(fd, "wb") as output:
            async with await open_model_download_response(session, request.url) as response:
                if response.status >= 400:
                    raise ModelDownloadError(f"Model download failed with HTTP {response.status}.", status=502)

                # NOTE(review): this compares against the announced
                # Content-Length; presumably remotes never send a compressed
                # body that is decoded on the fly -- confirm.
                expected_size = response.content_length
                async for chunk in response.content.iter_chunked(CHUNK_SIZE):
                    output.write(chunk)
                    bytes_written += len(chunk)

                if expected_size is not None and bytes_written != expected_size:
                    raise ModelDownloadError("Model download ended before all bytes were received.", status=502)

        # Publish only after the file has been closed (and thus flushed).
        os.replace(partial_path, destination.full_path)
    except BaseException:
        # BaseException rather than Exception: asyncio.CancelledError (e.g.
        # the client disconnected) is not an Exception subclass. Without
        # cleanup here the stale .part file would make every later attempt
        # fail with 409.
        try:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        except OSError:
            pass
        raise

    return download_response("downloaded", request, destination, bytes_written)
|
|
|
|
|
|
def download_response(
    status: str,
    request: ModelDownloadRequest,
    destination: ModelDownloadDestination,
    size: int | None = None,
) -> dict:
    """Build the result payload describing the outcome of a download.

    Args:
        status: Outcome label stored under ``"status"``.
        request: Request whose ``name`` is reported.
        destination: Destination whose ``directory`` and ``full_path`` are
            reported.
        size: Byte count; the ``"size"`` key is only present when this is
            not None.

    Returns:
        A dict with the keys ``"status"``, ``"name"``, ``"directory"`` and
        ``"path"``, plus ``"size"`` when a size was given.
    """
    payload = {
        "status": status,
        "name": request.name,
        "directory": destination.directory,
        "path": destination.full_path,
    }
    if size is None:
        return payload

    payload["size"] = size
    return payload
|