ComfyUI/tests-unit/app_test/model_download_test.py
adv0r 15d49a61b8 Address review feedback on /internal/models/download
- Disable aiohttp auto-redirects and re-validate every Location target
  against the same allowlist used for the initial URL, closing an SSRF
  vector where an allowed host could redirect to an arbitrary internal
  endpoint.
- Accept subdomains of allowlisted hosts so Hugging Face's LFS CDN
  (cdn-lfs.huggingface.co et al.) keeps working under the stricter
  redirect handling.
- Pass an explicit ClientTimeout (connect/sock_read) so hung remotes
  surface as errors instead of blocking the request handler forever.
- Log the exception value alongside the traceback on the 500 fallback.
- Add positive coverage for normalize_model_relative_path, Civitai URL
  allowlisting, and the redirect-following / SSRF-rejection branches of
  open_model_download_response.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-19 11:26:53 +02:00

245 lines
8.4 KiB
Python

import os
import pytest
import folder_paths
from app.model_download import (
ModelDownloadError,
ModelDownloadRequest,
is_allowed_model_download_url,
normalize_model_relative_path,
open_model_download_response,
parse_model_download_request,
resolve_model_download_destination,
)
class _FakeResponse:
    """Test double covering the slice of ``aiohttp.ClientResponse`` the redirect tests use.

    Holds a status code and a header mapping, and sets ``released`` once the
    response is released explicitly or its ``async with`` block exits.
    """

    def __init__(self, status, headers=None):
        self.released = False
        self.status = status
        # A falsy ``headers`` (``None`` or an empty mapping) becomes a fresh dict.
        self.headers = headers if headers else {}

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self.released = True

    def release(self):
        self.released = True
class _FakeSession:
    """Session double that replays a fixed script of responses, one per ``get``.

    Each ``get`` is logged in ``calls`` as ``(url, allow_redirects)`` so a test
    can check which URLs were requested and with what redirect setting.
    """

    def __init__(self, responses):
        self.calls = []
        # Copy so draining the queue never mutates the caller's sequence.
        self._responses = [*responses]

    async def get(self, url, allow_redirects, timeout):
        # ``timeout`` is accepted but neither inspected nor recorded.
        self.calls.append((url, allow_redirects))
        if self._responses:
            return self._responses.pop(0)
        raise AssertionError("Unexpected extra session.get call")
def test_parse_model_download_request_allows_huggingface_model_url():
    """A Hugging Face payload parses into an equal ``ModelDownloadRequest``."""
    parsed = parse_model_download_request(
        {
            "name": "nested/model.safetensors",
            "url": "https://huggingface.co/org/repo/resolve/main/model.safetensors",
            "directory": "checkpoints",
        }
    )

    expected = ModelDownloadRequest(
        name="nested/model.safetensors",
        url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
        directory="checkpoints",
    )
    assert parsed == expected
@pytest.mark.parametrize(
    "url",
    [
        # Plain-HTTP loopback target.
        "http://localhost:8000/model.safetensors",
        # Host that is accepted over HTTPS, requested over plain HTTP.
        "http://huggingface.co/org/repo/resolve/main/model.safetensors",
        # HTTPS, but a host outside the allowlist.
        "https://example.com/model.safetensors",
        # Allowlisted name appearing only as leading labels of a foreign domain.
        "https://huggingface.co.evil.com/model.safetensors",
        # Foreign domain that merely *ends with* the allowlisted name. Since
        # subdomains of allowlisted hosts are accepted, the match has to
        # respect the label boundary rather than being a bare suffix test.
        "https://evilhuggingface.co/model.safetensors",
        # Allowlisted name smuggled into the userinfo part; the host that
        # would actually be contacted is attacker.example.com.
        "https://huggingface.co@attacker.example.com/model.safetensors",
    ],
)
def test_download_url_allowlist_rejects_untrusted_or_plain_http_urls(url):
    """Every URL above must be refused by ``is_allowed_model_download_url``."""
    assert is_allowed_model_download_url(url) is False
@pytest.mark.parametrize(
    "url",
    [
        # Hugging Face: a model URL on the main host.
        "https://huggingface.co/org/repo/resolve/main/model.safetensors",
        # Hugging Face LFS CDN subdomains. `/resolve/main/...` redirects end up
        # here, so refusing them would break downloads.
        "https://cdn-lfs.huggingface.co/repos/abc/def/model.safetensors",
        "https://cdn-lfs-us-1.huggingface.co/repos/abc/def/model.safetensors",
        # Civitai download endpoints (supporting Civitai is a PR objective).
        "https://civitai.com/api/download/models/12345",
        "https://civitai.red/api/download/models/12345",
    ],
)
def test_download_url_allowlist_accepts_huggingface_and_civitai_urls(url):
    """Every URL above must be accepted by ``is_allowed_model_download_url``."""
    assert is_allowed_model_download_url(url) is True
@pytest.mark.parametrize(
    "name, expected",
    [
        # Forward-slash paths come back unchanged.
        ("model.safetensors", "model.safetensors"),
        ("sub/model.safetensors", "sub/model.safetensors"),
        ("nested/dir/model.safetensors", "nested/dir/model.safetensors"),
        # Windows-style separators are rewritten to forward slashes, so the
        # result matches the equivalent POSIX path.
        ("nested\\dir\\model.safetensors", "nested/dir/model.safetensors"),
    ],
)
def test_normalize_model_relative_path_accepts_safe_paths(name, expected):
    """Safe relative names normalize to the expected forward-slash form."""
    normalized = normalize_model_relative_path(name)
    assert normalized == expected
@pytest.mark.parametrize(
    "name",
    [
        # Parent-directory traversal, direct and nested.
        "../model.safetensors",
        "nested/../../model.safetensors",
        # The same traversal spelled with backslashes. Backslashes are
        # normalized to forward slashes, so these name the same escaping
        # paths as the two cases above and must be refused as well.
        "..\\model.safetensors",
        "nested\\..\\..\\model.safetensors",
        # Absolute path.
        "/absolute/model.safetensors",
        # Embedded NUL byte.
        "model.safetensors\x00",
    ],
)
def test_normalize_model_relative_path_rejects_unsafe_paths(name):
    """Each unsafe name must make ``normalize_model_relative_path`` raise."""
    with pytest.raises(ModelDownloadError):
        normalize_model_relative_path(name)
def test_parse_model_download_request_rejects_unsupported_extensions():
    """A ``.gguf`` payload must be refused with ``ModelDownloadError``."""
    gguf_payload = {
        "name": "model.gguf",
        "url": "https://huggingface.co/org/repo/resolve/main/model.gguf",
        "directory": "checkpoints",
    }

    with pytest.raises(ModelDownloadError):
        parse_model_download_request(gguf_payload)
def test_resolve_model_download_destination_uses_configured_model_folder(tmp_path, monkeypatch):
    """The destination is built under the folder configured for the directory."""
    checkpoints_dir = tmp_path / "models" / "checkpoints"
    # The folder is deliberately never created, so nothing can already exist.
    monkeypatch.setattr(
        folder_paths,
        "folder_names_and_paths",
        {"checkpoints": ([str(checkpoints_dir)], {".safetensors"})},
    )
    request = ModelDownloadRequest(
        name="sub/model.safetensors",
        url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
        directory="checkpoints",
    )

    resolved = resolve_model_download_destination(request)

    expected_path = os.path.join(str(checkpoints_dir), "sub", "model.safetensors")
    assert resolved.directory == "checkpoints"
    assert resolved.relative_path == "sub/model.safetensors"
    assert resolved.full_path == expected_path
    assert resolved.already_exists is False
def test_resolve_model_download_destination_reuses_existing_model(tmp_path, monkeypatch):
    """A file already on disk is reported through ``already_exists``."""
    checkpoints_dir = tmp_path / "models" / "checkpoints"
    checkpoints_dir.mkdir(parents=True)
    on_disk = checkpoints_dir / "model.safetensors"
    on_disk.write_bytes(b"model")
    monkeypatch.setattr(
        folder_paths,
        "folder_names_and_paths",
        {"checkpoints": ([str(checkpoints_dir)], {".safetensors"})},
    )
    request = ModelDownloadRequest(
        name="model.safetensors",
        url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
        directory="checkpoints",
    )

    resolved = resolve_model_download_destination(request)

    assert resolved.full_path == str(on_disk)
    assert resolved.already_exists is True
@pytest.mark.parametrize("directory", ["configs", "custom_nodes", "unknown"])
def test_resolve_model_download_destination_rejects_blocked_or_unknown_directories(tmp_path, monkeypatch, directory):
    """Resolution must raise for ``configs``, ``custom_nodes`` and an unregistered name."""
    # Only the first two names are registered; "unknown" has no entry at all.
    registered_folders = {
        "configs": ([str(tmp_path / "configs")], {".yaml"}),
        "custom_nodes": ([str(tmp_path / "custom_nodes")], set()),
    }
    monkeypatch.setattr(folder_paths, "folder_names_and_paths", registered_folders)
    request = ModelDownloadRequest(
        name="model.safetensors",
        url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
        directory=directory,
    )

    with pytest.raises(ModelDownloadError):
        resolve_model_download_destination(request)
@pytest.mark.asyncio
async def test_open_model_download_response_follows_allowed_subdomain_redirect():
    """A redirect from huggingface.co to its cdn-lfs subdomain must be followed."""
    start_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    cdn_url = "https://cdn-lfs.huggingface.co/repos/abc/model.safetensors"
    redirect = _FakeResponse(302, {"Location": cdn_url})
    final = _FakeResponse(200)
    session = _FakeSession([redirect, final])

    response = await open_model_download_response(session, start_url)

    assert response.status == 200
    # Both hops were requested in order, each with auto-redirects switched off.
    assert session.calls == [(start_url, False), (cdn_url, False)]
@pytest.mark.asyncio
async def test_open_model_download_response_rejects_offsite_redirect():
    """A redirect to a host outside the allowlist must raise a 403, not be followed."""
    start_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    offsite_redirect = _FakeResponse(302, {"Location": "https://attacker.example.com/payload"})
    session = _FakeSession([offsite_redirect])

    with pytest.raises(ModelDownloadError) as exc_info:
        await open_model_download_response(session, start_url)

    assert exc_info.value.status == 403
    # Redirects must have been disabled on the initial request; otherwise the
    # Location check would never get a chance to run.
    _, allow_redirects = session.calls[0]
    assert allow_redirects is False
@pytest.mark.asyncio
async def test_open_model_download_response_rejects_redirect_without_location():
    """A 302 that carries no ``Location`` header must surface as a 502."""
    start_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    headerless_redirect = _FakeResponse(302)
    session = _FakeSession([headerless_redirect])

    with pytest.raises(ModelDownloadError) as exc_info:
        await open_model_download_response(session, start_url)

    assert exc_info.value.status == 502
@pytest.mark.asyncio
async def test_open_model_download_response_stops_after_too_many_redirects():
    """A long chain of allowlisted redirects must end in a 502."""
    # Build ten *distinct* responses. ``[response] * 10`` would queue the same
    # object ten times, so state set on one hop (e.g. ``released``) would leak
    # into every later hop.
    # NOTE(review): ten hops is assumed to exceed the redirect cap inside
    # open_model_download_response -- confirm against the implementation.
    session = _FakeSession(
        [
            _FakeResponse(302, {"Location": "https://cdn-lfs.huggingface.co/loop"})
            for _ in range(10)
        ]
    )

    with pytest.raises(ModelDownloadError) as exc_info:
        await open_model_download_response(
            session, "https://huggingface.co/org/repo/resolve/main/model.safetensors"
        )

    assert exc_info.value.status == 502