ComfyUI/tests-unit/app_test/model_download_test.py
2026-05-18 15:29:15 +02:00

115 lines
3.9 KiB
Python

import os
import pytest
import folder_paths
from app.model_download import (
ModelDownloadError,
ModelDownloadRequest,
is_allowed_model_download_url,
normalize_model_relative_path,
parse_model_download_request,
resolve_model_download_destination,
)
def test_parse_model_download_request_allows_huggingface_model_url():
    """An HTTPS Hugging Face request parses into the equivalent ModelDownloadRequest."""
    hf_url = "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    payload = {
        "name": "nested/model.safetensors",
        "url": hf_url,
        "directory": "checkpoints",
    }

    parsed = parse_model_download_request(payload)

    # The nested sub-path in "name" is expected to survive parsing unchanged.
    expected = ModelDownloadRequest(
        name="nested/model.safetensors",
        url=hf_url,
        directory="checkpoints",
    )
    assert parsed == expected
@pytest.mark.parametrize(
    "url",
    [
        # Local, plain-HTTP, or simply non-allowlisted hosts.
        "http://localhost:8000/model.safetensors",
        "http://huggingface.co/org/repo/resolve/main/model.safetensors",
        "https://example.com/model.safetensors",
        # Host-confusion variants: the trusted host name appears in the URL
        # text, but the host actually contacted is an untrusted one.
        "https://huggingface.co.evil.example/org/repo/resolve/main/model.safetensors",
        "https://evil.example/huggingface.co/org/repo/resolve/main/model.safetensors",
        "https://huggingface.co@evil.example/org/repo/resolve/main/model.safetensors",
    ],
)
def test_download_url_allowlist_rejects_untrusted_or_plain_http_urls(url):
    """Reject plain-HTTP URLs and any URL whose real host is not allowlisted."""
    assert is_allowed_model_download_url(url) is False


def test_download_url_allowlist_accepts_https_huggingface_url():
    """The canonical HTTPS Hugging Face URL must pass the allowlist.

    Without a positive case, an allowlist that rejects everything would still
    satisfy the rejection test above.
    """
    assert is_allowed_model_download_url(
        "https://huggingface.co/org/repo/resolve/main/model.safetensors"
    ) is True
@pytest.mark.parametrize(
    "unsafe_name",
    [
        "../model.safetensors",            # escapes the model folder directly
        "nested/../../model.safetensors",  # escapes via a nested segment
        "/absolute/model.safetensors",     # absolute path instead of relative
        "model.safetensors\x00",           # embedded NUL byte
    ],
)
def test_normalize_model_relative_path_rejects_unsafe_paths(unsafe_name):
    """Each unsafe relative name must raise ModelDownloadError."""
    with pytest.raises(ModelDownloadError):
        normalize_model_relative_path(unsafe_name)
def test_parse_model_download_request_rejects_unsupported_extensions():
    """A request whose file extension is not supported must raise ModelDownloadError."""
    gguf_payload = dict(
        name="model.gguf",
        url="https://huggingface.co/org/repo/resolve/main/model.gguf",
        directory="checkpoints",
    )

    with pytest.raises(ModelDownloadError):
        parse_model_download_request(gguf_payload)
def test_resolve_model_download_destination_uses_configured_model_folder(tmp_path, monkeypatch):
    """A not-yet-downloaded model resolves under the folder configured for its directory key."""
    checkpoints_dir = tmp_path / "models" / "checkpoints"
    # Replace the folder registry with a single "checkpoints" entry that points
    # at a temp directory. The directory itself is intentionally not created.
    fake_registry = {"checkpoints": ([str(checkpoints_dir)], {".safetensors"})}
    monkeypatch.setattr(folder_paths, "folder_names_and_paths", fake_registry)

    request = ModelDownloadRequest(
        name="sub/model.safetensors",
        url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
        directory="checkpoints",
    )
    dest = resolve_model_download_destination(request)

    expected_path = os.path.join(str(checkpoints_dir), "sub", "model.safetensors")
    assert dest.directory == "checkpoints"
    assert dest.relative_path == "sub/model.safetensors"
    assert dest.full_path == expected_path
    assert dest.already_exists is False
def test_resolve_model_download_destination_reuses_existing_model(tmp_path, monkeypatch):
    """A model file that already exists on disk is reported with already_exists=True."""
    checkpoints_dir = tmp_path / "models" / "checkpoints"
    checkpoints_dir.mkdir(parents=True)
    present_file = checkpoints_dir / "model.safetensors"
    present_file.write_bytes(b"model")

    fake_registry = {"checkpoints": ([str(checkpoints_dir)], {".safetensors"})}
    monkeypatch.setattr(folder_paths, "folder_names_and_paths", fake_registry)

    request = ModelDownloadRequest(
        name="model.safetensors",
        url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
        directory="checkpoints",
    )
    dest = resolve_model_download_destination(request)

    assert dest.full_path == str(present_file)
    assert dest.already_exists is True
@pytest.mark.parametrize("directory", ["configs", "custom_nodes", "unknown"])
def test_resolve_model_download_destination_rejects_blocked_or_unknown_directories(tmp_path, monkeypatch, directory):
    """Downloads into blocked folder keys or unregistered keys must raise ModelDownloadError."""
    # "configs" and "custom_nodes" are registered here but are expected to be
    # blocked. "unknown" is deliberately absent from the registry.
    fake_registry = {
        "configs": ([str(tmp_path / "configs")], {".yaml"}),
        "custom_nodes": ([str(tmp_path / "custom_nodes")], set()),
    }
    monkeypatch.setattr(folder_paths, "folder_names_and_paths", fake_registry)

    request = ModelDownloadRequest(
        name="model.safetensors",
        url="https://huggingface.co/org/repo/resolve/main/model.safetensors",
        directory=directory,
    )
    with pytest.raises(ModelDownloadError):
        resolve_model_download_destination(request)