# Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced
# 2026-07-03 13:19:23 +08:00 (85 lines, 2.9 KiB, Python).
"""URL allowlist for server-side model fetches (PRD section 9.1).

Default-deny. A URL is downloadable only when its parsed host + scheme are
allowlisted AND (unless explicitly relaxed) its URL path ends in a known
model file extension.

The built-in host defaults mirror the frontend's ``isModelDownloadable``
allowlist so the two flows agree on what is eligible; ``--download-allowed-hosts``
extends it with https-only hosts for self-hosted mirrors. Matching is done on
``urlparse().hostname`` (never a raw string prefix), so userinfo tricks like
``http://127.0.0.1@169.254.169.254/x.safetensors`` — whose real host is the
metadata IP, not 127.0.0.1 — cannot slip past.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
from comfy.cli_args import args
|
|
|
|
# Built-in allowlist: each host maps to the URL schemes it may be fetched
# over. Kept in step with the frontend (HuggingFace, Civitai, localhost).
# Hosts added through --download-allowed-hosts are granted https only.
_DEFAULT_ALLOWED_HOSTS: dict[str, set[str]] = {
    "huggingface.co": {"https"},
    "civitai.com": {"https"},
    "localhost": {"http", "https"},
    "127.0.0.1": {"http", "https"},
}

# Host names whose loopback resolution is deliberately tolerated; this is what
# makes the "download a local model" feature work. For any other host, the
# SSRF resolver rejects a loopback resolution.
LOOPBACK_HOSTS = frozenset(("localhost", "127.0.0.1", "::1"))

# Accepted model-file suffixes (frontend parity). Kept as a tuple so it can be
# passed directly to ``str.endswith``.
ALLOWED_MODEL_EXTENSIONS = (
    ".safetensors", ".sft", ".ckpt", ".pth", ".pt", ".gguf", ".bin",
)
|
|
|
|
|
|
def _normalize_extra_hosts(raw: str | list[str] | tuple[str, ...] | None) -> list[str]:
    """Flatten a ``--download-allowed-hosts`` value into lowercase host names.

    ``raw`` may be a sequence of strings or a single string. Each entry may
    also carry several comma-separated hosts; DNS host names cannot contain
    commas, so splitting never breaks a valid entry. Surrounding whitespace is
    stripped and empty pieces are dropped. ``None`` or empty input yields ``[]``.
    """
    if not raw:
        return []
    # A bare string would otherwise be iterated character by character.
    entries = [raw] if isinstance(raw, str) else raw
    hosts: list[str] = []
    for entry in entries:
        for piece in entry.split(","):
            host = piece.strip().lower()
            if host:
                hosts.append(host)
    return hosts


def _allowed_hosts() -> dict[str, set[str]]:
    """Build the effective host -> allowed-schemes map.

    Starts from a fresh copy of ``_DEFAULT_ALLOWED_HOSTS``, so the defaults are
    never mutated. It then grants ``https`` to every host supplied via
    ``--download-allowed-hosts``. The map is rebuilt on each call, so it
    reflects the current value of ``args``.
    """
    hosts = {h: set(s) for h, s in _DEFAULT_ALLOWED_HOSTS.items()}
    for host in _normalize_extra_hosts(getattr(args, "download_allowed_hosts", None)):
        # setdefault keeps any broader default scheme set (e.g. localhost's
        # http) and only adds https on top of it.
        hosts.setdefault(host, set()).add("https")
    return hosts
|
|
|
|
|
|
def is_host_allowed(host: str | None, scheme: str | None) -> bool:
    """Return whether ``host`` may be fetched over ``scheme``.

    Intended to be applied to the initial URL and re-applied on every redirect
    hop (PRD section 9.2), so an allowlisted URL cannot 30x its way to an
    off-list host. A missing (``None`` or empty) host or scheme is never
    allowed. Both values are compared case-insensitively.
    """
    if host and scheme:
        permitted = _allowed_hosts().get(host.lower())
        if permitted is None:
            return False
        return scheme.lower() in permitted
    return False
|
|
|
|
|
|
def has_allowed_extension(path: str, allow_any_extension: bool = False) -> bool:
    """Return True if ``path`` ends in a known model file extension.

    The suffix comparison is case-insensitive. Passing
    ``allow_any_extension=True`` skips the check and always returns True.
    """
    if not allow_any_extension:
        lowered = path.lower()
        return lowered.endswith(ALLOWED_MODEL_EXTENSIONS)
    return True
|
|
|
|
|
|
def is_url_allowed(url: str, allow_any_extension: bool = False) -> bool:
    """Check whether ``url`` is permitted as a server-side download source.

    A URL passes only when all of these hold:

    * it is a non-empty ``str``;
    * ``urlparse`` accepts it;
    * its scheme/host pair passes ``is_host_allowed``;
    * unless ``allow_any_extension`` is set, it has no ``;params`` and its
      path ends in a known model extension.

    A URL that ``urlparse`` rejects with ``ValueError`` yields False instead
    of raising.
    """
    if not isinstance(url, str) or not url:
        return False
    try:
        parsed = urlparse(url)
    except ValueError:
        # e.g. an unbalanced "[" in the netloc ("Invalid IPv6 URL").
        return False
    if not is_host_allowed(parsed.hostname, parsed.scheme):
        return False
    if not allow_any_extension and parsed.params:
        # urlparse splits RFC 1808 ";params" off the last path segment, so
        # parsed.path would not be the final segment actually requested
        # (".../model.safetensors;.exe" would be checked as ".safetensors").
        # Default-deny rather than validate a truncated filename.
        return False
    return has_allowed_extension(parsed.path, allow_any_extension)
|