mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-21 07:19:42 +08:00
Address review feedback on /internal/models/download
- Disable aiohttp auto-redirects and re-validate every Location target against the same allowlist used for the initial URL, closing an SSRF vector where an allowed host could redirect to an arbitrary internal endpoint.
- Accept subdomains of allowlisted hosts so Hugging Face's LFS CDN (cdn-lfs.huggingface.co et al.) keeps working under the stricter redirect handling.
- Pass an explicit ClientTimeout (connect/sock_read) so hung remotes surface as errors instead of blocking the request handler forever.
- Log the exception value alongside the traceback on the 500 fallback.
- Add positive coverage for normalize_model_relative_path, Civitai URL allowlisting, and the redirect-following / SSRF-rejection branches of open_model_download_response.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
f9eac7477a
commit
15d49a61b8
@ -72,8 +72,8 @@ class InternalRoutes:
|
|||||||
)
|
)
|
||||||
except ModelDownloadError as err:
|
except ModelDownloadError as err:
|
||||||
return web.json_response({"error": str(err)}, status=err.status)
|
return web.json_response({"error": str(err)}, status=err.status)
|
||||||
except Exception:
|
except Exception as err:
|
||||||
logging.exception("Failed to download model")
|
logging.exception("Failed to download model: %s", err)
|
||||||
return web.json_response({"error": "Failed to download model."}, status=500)
|
return web.json_response({"error": "Failed to download model."}, status=500)
|
||||||
|
|
||||||
response_status = 200 if result["status"] == "already_exists" else 201
|
response_status = 200 if result["status"] == "already_exists" else 201
|
||||||
|
|||||||
@ -4,9 +4,9 @@ import os
|
|||||||
import posixpath
|
import posixpath
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import PurePosixPath
|
from pathlib import PurePosixPath
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession, ClientTimeout
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
@ -16,6 +16,18 @@ ALLOWED_DOWNLOAD_SUFFIXES = (".safetensors", ".sft", ".ckpt", ".pth", ".pt")
|
|||||||
BLOCKED_MODEL_FOLDERS = {"configs", "custom_nodes"}
|
BLOCKED_MODEL_FOLDERS = {"configs", "custom_nodes"}
|
||||||
CHUNK_SIZE = 1024 * 1024
|
CHUNK_SIZE = 1024 * 1024
|
||||||
|
|
||||||
|
# Bound the network call so a hung remote eventually surfaces an error
|
||||||
|
# instead of blocking the request handler forever. ``sock_read`` is the
|
||||||
|
# inter-chunk read timeout, which is the right knob for long downloads:
|
||||||
|
# a slow-but-progressing transfer keeps making progress, while a stalled
|
||||||
|
# socket fails predictably.
|
||||||
|
DOWNLOAD_TIMEOUT = ClientTimeout(total=None, connect=30, sock_connect=30, sock_read=300)
|
||||||
|
|
||||||
|
# Maximum number of redirects we follow manually. Hugging Face typically
|
||||||
|
# redirects ``/resolve/main/...`` to a single CDN URL, so a small budget
|
||||||
|
# is enough while still preventing redirect loops.
|
||||||
|
MAX_DOWNLOAD_REDIRECTS = 5
|
||||||
|
|
||||||
WHITE_LISTED_DOWNLOAD_URLS = {
|
WHITE_LISTED_DOWNLOAD_URLS = {
|
||||||
"https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
|
"https://huggingface.co/stabilityai/stable-zero123/resolve/main/stable_zero123.ckpt",
|
||||||
"https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
|
"https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_depth_sd14v1.pth?download=true",
|
||||||
@ -71,6 +83,14 @@ def parse_model_download_request(data) -> ModelDownloadRequest:
|
|||||||
|
|
||||||
|
|
||||||
def is_allowed_model_download_url(url: str) -> bool:
|
def is_allowed_model_download_url(url: str) -> bool:
|
||||||
|
"""Return True for URLs we are willing to fetch on behalf of the user.
|
||||||
|
|
||||||
|
The same predicate is applied to the user-supplied URL and to every
|
||||||
|
redirect target, so SSRF via redirects on an allowed host is contained
|
||||||
|
to the same allowlist. Subdomains of allowlisted hosts are accepted
|
||||||
|
because Hugging Face and Civitai both serve actual file payloads from
|
||||||
|
CDN subdomains (e.g. ``cdn-lfs.huggingface.co``).
|
||||||
|
"""
|
||||||
if url in WHITE_LISTED_DOWNLOAD_URLS:
|
if url in WHITE_LISTED_DOWNLOAD_URLS:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -82,7 +102,14 @@ def is_allowed_model_download_url(url: str) -> bool:
|
|||||||
if parsed.scheme != "https":
|
if parsed.scheme != "https":
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return (parsed.hostname or "").lower() in ALLOWED_DOWNLOAD_HOSTS
|
host = (parsed.hostname or "").lower()
|
||||||
|
if not host:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for allowed in ALLOWED_DOWNLOAD_HOSTS:
|
||||||
|
if host == allowed or host.endswith("." + allowed):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def normalize_model_relative_path(name: str) -> str:
|
def normalize_model_relative_path(name: str) -> str:
|
||||||
@ -164,6 +191,36 @@ def safe_join(root: str, relative_path: str) -> str:
|
|||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
|
async def open_model_download_response(session: ClientSession, url: str):
|
||||||
|
"""GET ``url`` with explicit timeout and an allowlist-checked redirect chain.
|
||||||
|
|
||||||
|
aiohttp follows redirects by default, which would let an allowed host
|
||||||
|
redirect to an arbitrary internal target (SSRF). We disable automatic
|
||||||
|
following and validate every ``Location`` against the same allowlist
|
||||||
|
used for the initial URL.
|
||||||
|
"""
|
||||||
|
current_url = url
|
||||||
|
for _ in range(MAX_DOWNLOAD_REDIRECTS + 1):
|
||||||
|
response = await session.get(
|
||||||
|
current_url,
|
||||||
|
allow_redirects=False,
|
||||||
|
timeout=DOWNLOAD_TIMEOUT,
|
||||||
|
)
|
||||||
|
if response.status not in (301, 302, 303, 307, 308):
|
||||||
|
return response
|
||||||
|
|
||||||
|
location = response.headers.get("Location", "").strip()
|
||||||
|
response.release()
|
||||||
|
if not location:
|
||||||
|
raise ModelDownloadError("Redirect response missing Location header.", status=502)
|
||||||
|
next_url = urljoin(current_url, location)
|
||||||
|
if not is_allowed_model_download_url(next_url):
|
||||||
|
raise ModelDownloadError("Model download redirect target is not allowed.", status=403)
|
||||||
|
current_url = next_url
|
||||||
|
|
||||||
|
raise ModelDownloadError("Too many redirects while downloading model.", status=502)
|
||||||
|
|
||||||
|
|
||||||
async def download_model_to_destination(
|
async def download_model_to_destination(
|
||||||
session: ClientSession,
|
session: ClientSession,
|
||||||
request: ModelDownloadRequest,
|
request: ModelDownloadRequest,
|
||||||
@ -183,7 +240,7 @@ async def download_model_to_destination(
|
|||||||
bytes_written = 0
|
bytes_written = 0
|
||||||
try:
|
try:
|
||||||
with os.fdopen(fd, "wb") as output:
|
with os.fdopen(fd, "wb") as output:
|
||||||
async with session.get(request.url) as response:
|
async with await open_model_download_response(session, request.url) as response:
|
||||||
if response.status >= 400:
|
if response.status >= 400:
|
||||||
raise ModelDownloadError(f"Model download failed with HTTP {response.status}.", status=502)
|
raise ModelDownloadError(f"Model download failed with HTTP {response.status}.", status=502)
|
||||||
|
|
||||||
|
|||||||
@ -8,11 +8,44 @@ from app.model_download import (
|
|||||||
ModelDownloadRequest,
|
ModelDownloadRequest,
|
||||||
is_allowed_model_download_url,
|
is_allowed_model_download_url,
|
||||||
normalize_model_relative_path,
|
normalize_model_relative_path,
|
||||||
|
open_model_download_response,
|
||||||
parse_model_download_request,
|
parse_model_download_request,
|
||||||
resolve_model_download_destination,
|
resolve_model_download_destination,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResponse:
|
||||||
|
"""Minimal stand-in for ``aiohttp.ClientResponse`` for the redirect tests."""
|
||||||
|
|
||||||
|
def __init__(self, status, headers=None):
|
||||||
|
self.status = status
|
||||||
|
self.headers = headers or {}
|
||||||
|
self.released = False
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
self.released = True
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
self.released = True
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
"""Hands out queued ``_FakeResponse`` objects in order."""
|
||||||
|
|
||||||
|
def __init__(self, responses):
|
||||||
|
self._responses = list(responses)
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
async def get(self, url, allow_redirects, timeout):
|
||||||
|
self.calls.append((url, allow_redirects))
|
||||||
|
if not self._responses:
|
||||||
|
raise AssertionError("Unexpected extra session.get call")
|
||||||
|
return self._responses.pop(0)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_model_download_request_allows_huggingface_model_url():
|
def test_parse_model_download_request_allows_huggingface_model_url():
|
||||||
request = parse_model_download_request({
|
request = parse_model_download_request({
|
||||||
"name": "nested/model.safetensors",
|
"name": "nested/model.safetensors",
|
||||||
@ -33,12 +66,46 @@ def test_parse_model_download_request_allows_huggingface_model_url():
|
|||||||
"http://localhost:8000/model.safetensors",
|
"http://localhost:8000/model.safetensors",
|
||||||
"http://huggingface.co/org/repo/resolve/main/model.safetensors",
|
"http://huggingface.co/org/repo/resolve/main/model.safetensors",
|
||||||
"https://example.com/model.safetensors",
|
"https://example.com/model.safetensors",
|
||||||
|
"https://huggingface.co.evil.com/model.safetensors",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_download_url_allowlist_rejects_untrusted_or_plain_http_urls(url):
|
def test_download_url_allowlist_rejects_untrusted_or_plain_http_urls(url):
|
||||||
assert is_allowed_model_download_url(url) is False
|
assert is_allowed_model_download_url(url) is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"url",
|
||||||
|
[
|
||||||
|
# Direct HF model URLs.
|
||||||
|
"https://huggingface.co/org/repo/resolve/main/model.safetensors",
|
||||||
|
# HF LFS CDN subdomains: this is where `/resolve/main/...` redirects
|
||||||
|
# land, so the allowlist must accept them or downloads break.
|
||||||
|
"https://cdn-lfs.huggingface.co/repos/abc/def/model.safetensors",
|
||||||
|
"https://cdn-lfs-us-1.huggingface.co/repos/abc/def/model.safetensors",
|
||||||
|
# Civitai download endpoints (PR objective: support Civitai too).
|
||||||
|
"https://civitai.com/api/download/models/12345",
|
||||||
|
"https://civitai.red/api/download/models/12345",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_download_url_allowlist_accepts_huggingface_and_civitai_urls(url):
|
||||||
|
assert is_allowed_model_download_url(url) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"name, expected",
|
||||||
|
[
|
||||||
|
("model.safetensors", "model.safetensors"),
|
||||||
|
("sub/model.safetensors", "sub/model.safetensors"),
|
||||||
|
("nested/dir/model.safetensors", "nested/dir/model.safetensors"),
|
||||||
|
# Backslashes are normalized to forward slashes so Windows-style
|
||||||
|
# paths land in the same place as the POSIX equivalents.
|
||||||
|
("nested\\dir\\model.safetensors", "nested/dir/model.safetensors"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_normalize_model_relative_path_accepts_safe_paths(name, expected):
|
||||||
|
assert normalize_model_relative_path(name) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"name",
|
"name",
|
||||||
[
|
[
|
||||||
@ -112,3 +179,66 @@ def test_resolve_model_download_destination_rejects_blocked_or_unknown_directori
|
|||||||
url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
|
url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
|
||||||
directory=directory,
|
directory=directory,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_open_model_download_response_follows_allowed_subdomain_redirect():
|
||||||
|
"""HF redirects /resolve/main/... to cdn-lfs.huggingface.co; that must work."""
|
||||||
|
session = _FakeSession([
|
||||||
|
_FakeResponse(302, {"Location": "https://cdn-lfs.huggingface.co/repos/abc/model.safetensors"}),
|
||||||
|
_FakeResponse(200),
|
||||||
|
])
|
||||||
|
|
||||||
|
response = await open_model_download_response(
|
||||||
|
session, "https://huggingface.co/org/repo/resolve/main/model.safetensors"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert session.calls == [
|
||||||
|
("https://huggingface.co/org/repo/resolve/main/model.safetensors", False),
|
||||||
|
("https://cdn-lfs.huggingface.co/repos/abc/model.safetensors", False),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_open_model_download_response_rejects_offsite_redirect():
|
||||||
|
"""A redirect leaving the allowlist must surface as a 403 instead of being followed."""
|
||||||
|
session = _FakeSession([
|
||||||
|
_FakeResponse(302, {"Location": "https://attacker.example.com/payload"}),
|
||||||
|
])
|
||||||
|
|
||||||
|
with pytest.raises(ModelDownloadError) as exc_info:
|
||||||
|
await open_model_download_response(
|
||||||
|
session, "https://huggingface.co/org/repo/resolve/main/model.safetensors"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status == 403
|
||||||
|
# The initial request was issued with redirects disabled, otherwise
|
||||||
|
# the validation above would be a no-op.
|
||||||
|
assert session.calls[0][1] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_open_model_download_response_rejects_redirect_without_location():
|
||||||
|
session = _FakeSession([_FakeResponse(302)])
|
||||||
|
|
||||||
|
with pytest.raises(ModelDownloadError) as exc_info:
|
||||||
|
await open_model_download_response(
|
||||||
|
session, "https://huggingface.co/org/repo/resolve/main/model.safetensors"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status == 502
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_open_model_download_response_stops_after_too_many_redirects():
|
||||||
|
session = _FakeSession(
|
||||||
|
[_FakeResponse(302, {"Location": "https://cdn-lfs.huggingface.co/loop"})] * 10
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ModelDownloadError) as exc_info:
|
||||||
|
await open_model_download_response(
|
||||||
|
session, "https://huggingface.co/org/repo/resolve/main/model.safetensors"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status == 502
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user